{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "# Pytorch 快速使用手册"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.2.0+cpu\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] =\"2\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "print(torch.__version__)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 查看是否有GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") ##判断是否有gpu\n",
    "import torch.backends.cudnn as cudnn\n",
    "if torch.cuda.is_available():\n",
    "    cudnn.benchmark = True\n",
    "\n",
    "    \n",
    "def seed_everything(seed):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    torch.backends.cudnn.enabled = True\n",
    "    #tf.set_random_seed(seed)\n",
    "seed_everything(2019)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "## 基本命令"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 224, 224]) torch.FloatTensor\n",
      "torch.Size([2, 3, 224, 224]) torch.FloatTensor\n",
      "torch.Size([2, 3, 224, 224])\n",
      "tensor([ 0.1459, -2.1029])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(2,3,224,224)\n",
    "print(x.size(),x.type())\n",
    "y = torch.randn(2,3,224,224)\n",
    "z = x+y\n",
    "\n",
    "print(z.size(),z.type())\n",
    "print(y.add_(x).size())\n",
    "print(y[1,1,1,:2])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.Tensor与numpy相互转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a: torch.Size([5]) <class 'torch.Tensor'>\n",
      "b: (5,) <class 'numpy.ndarray'>\n",
      "b <class 'numpy.ndarray'>\n",
      "e <class 'torch.Tensor'>\n"
     ]
    }
   ],
   "source": [
    "# torch.Tensor ->numpy\n",
    "a = torch.ones(5)\n",
    "print(\"a:\",a.size(),type(a))\n",
    "b = a.numpy()\n",
    "print('b:',b.shape,type(b))\n",
    "\n",
    "# numpy ->torch.Tensor\n",
    "e=torch.from_numpy(b)\n",
    "print('b',type(b))\n",
    "print('e',type(e))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.Tensor与cuda相互转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.Tensor ->cuda \n",
    "x = torch.randn(2,3,224,224)\n",
    "y = torch.randn(2,3,224,224)\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "    x = x.to(device)\n",
    "    y = y.cuda()\n",
    "    x + y\n",
    "\n",
    "# cuda->torch.Tensor \n",
    "b = x.cpu()\n",
    "c = y.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.Tensor与Variable相互转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y: <class 'torch.Tensor'> torch.Size([2, 3, 224, 224]) torch.FloatTensor\n",
      "c: torch.FloatTensor torch.Size([2, 3, 224, 224])\n"
     ]
    }
   ],
   "source": [
    "# torch.Tensor ->Variable\n",
    "from torch.autograd import Variable\n",
    "x = torch.randn(2,3,224,224)\n",
    "y = Variable(x)\n",
    "print('y:',type(y),y.size(),y.type())\n",
    "\n",
    "# Variable -> torch.Tensor\n",
    "c=y.data#通过 Variable.data 方法相当于将Variable中的torch.tensor 取出来\n",
    "print('c:',c.type(),c.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### numpy与list相互转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1: <class 'list'>\n",
      "g: <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "#  numpy -> list\n",
    "import numpy as np\n",
    "d = np.random.random((2,4))\n",
    "f1=d.tolist()\n",
    "f2=list(d)\n",
    "\n",
    "# list -> numpy\n",
    "g=np.asarray(f2)\n",
    "g=np.array(f2)\n",
    "print('f1:',type(f1))\n",
    "print('g:',type(g))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### long,int,double,float,byte类型转化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.FloatTensor\n",
      "float: torch.FloatTensor\n",
      "long: torch.LongTensor\n",
      "int: torch.IntTensor\n",
      "double: torch.DoubleTensor\n",
      "byte: torch.ByteTensor\n"
     ]
    }
   ],
   "source": [
    "\n",
    "x = torch.randn(2,3,32,32)\n",
    "print(x.type())\n",
    "y = x.float()\n",
    "print(\"float:\",y.type())\n",
    "y = x.long()\n",
    "print(\"long:\",y.type())\n",
    "y = x.int()\n",
    "print(\"int:\",y.type())\n",
    "y = x.double()\n",
    "print(\"double:\",y.type())\n",
    "y = x.byte()\n",
    "print(\"byte:\",y.type())\n",
    "\n",
    "# # 一般只要在Tensor后加long(), int(), double(),float(),byte()等函数就能将Tensor进行类型转换；\n",
    "# 例如：Torch.LongTensor--->Torch.FloatTensor, 直接使用data.float()即可"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Autograd: 自动求导"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x = tensor([[1., 1.],\n",
      "        [1., 1.]], requires_grad=True)\n",
      "x + 2 = tensor([[3., 3.],\n",
      "        [3., 3.]], grad_fn=<AddBackward0>)\n",
      "(x + 2) grad_fn <AddBackward0 object at 0x7fe24c70eb70>\n",
      "x grad_fn None\n"
     ]
    }
   ],
   "source": [
    "x = Variable(torch.ones(2, 2), requires_grad=True)\n",
    "print(\"x =\", x)\n",
    "y = x + 2\n",
    "print(\"x + 2 =\", y)\n",
    "print(\"(x + 2) grad_fn\", y.grad_fn)\n",
    "print(\"x grad_fn\",x.grad_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[27., 27.],\n",
      "        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)\n",
      "x grad= tensor([[4.5000, 4.5000],\n",
      "        [4.5000, 4.5000]])\n"
     ]
    }
   ],
   "source": [
    "z = y * y * 3\n",
    "out = z.mean()\n",
    "\n",
    "print(z, out)\n",
    "out.backward() ## out最好是标量\n",
    "## out = z\n",
    "## out.backward(torch.ones(out.shape))\n",
    "print(\"x grad=\", x.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [contiguous](https://zhuanlan.zhihu.com/p/64551412)\n",
    "\n",
    "contiguous 本身是形容词，表示连续的，关于 contiguous,\n",
    "PyTorch 提供了is_contiguous、contiguous(形容词动用)两个方法,\n",
    "分别用于判定Tensor是否是 contiguous 的，以及保证Tensor是contiguous的.\n",
    "\n",
    "is_contiguous直观的解释是**Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致**。\n",
    "\n",
    "Tensor多维数组底层实现是使用一块连续内存的1维数组（行优先顺序存储).\n",
    "(如transpose、permute、narrow、expand)与原Tensor是共享内存中的数据，\n",
    "不会改变底层数组的存储，但原来在语义上相邻、内存里也相邻的元素在执行这样的操作后，\n",
    "在语义上相邻，但在内存不相邻，即不连续了(is not contiguous).\n",
    "\n",
    "* torch.view等方法操作需要连续的Tensor.\n",
    "\n",
    "transpose,permute 操作虽然没有修改底层一维数组，\n",
    "但是新建了一份Tensor元信息，并在新的元信息中的 重新指定 \n",
    "stride。torch.view 方法约定了不修改数组本身，只是使用新的\n",
    "形状查看数据。如果我们在 transpose、permute 操作后执行 view."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.arange(12).reshape(3,4)\n",
    "print(\"t:\",t)\n",
    "print(\"t.stride:\", t.stride())\n",
    "t2 = t.transpose(0,1)\n",
    "print(\"t.transpose:\",t2)\n",
    "print(\"t.transpose.stride:\", t2.stride())\n",
    "assert t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组\n",
    "t.is_contiguous(),t2.is_contiguous() # t连续，t2不连续"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "## 搭建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Segmentation(object):\n",
    "    def __init__(self, root, mode=None,shuffle=True, transform=None):\n",
    "        super(Segmentation, self).__init__()\n",
    "\n",
    "        self.images, self.masks = _get_pairs(root, mode,shuffle)\n",
    "        self.mode = mode\n",
    "        self.transform = transform\n",
    "        self.shuffle = shuffle\n",
    "        self.num_class = 5\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "\n",
    "        img = cv2.imread(self.images[index])\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "        mask = cv2.imread(self.masks[index],-1)\n",
    "        # general resize, normalize and to Tensor\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        return torch.FloatTensor(img), torch.LongTensor(mask)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "\n",
    "    @property\n",
    "    def pred_offset(self):\n",
    "        return 1\n",
    "\n",
    "    @property\n",
    "    def classes(self):\n",
    "        \"\"\"Category names.\"\"\"\n",
    "        return (\"0\",\"1\",\"2\",\"3\",\"4\")\n",
    "\n",
    "def _get_pairs(folder, mode='train', shuffle=True):\n",
    "\n",
    "    if mode == 'train':\n",
    "        img_folder = os.path.join(folder, '%d_imgs'%crop_size)\n",
    "        mask_folder = os.path.join(folder, '%d_label'%crop_size)\n",
    "    else:\n",
    "        img_folder = os.path.join(folder, '%d_imgs_val'%crop_size)\n",
    "        mask_folder = os.path.join(folder, '%d_label_val'%crop_size)\n",
    "\n",
    "    img_paths = glob(img_folder+\"/*\")\n",
    "    mask_paths = glob(mask_folder + \"/*\")\n",
    "    if shuffle:\n",
    "        img_paths = np.array(img_paths)\n",
    "        mask_paths = np.array(mask_paths)\n",
    "        index = [i for i in range(len(img_paths))]\n",
    "        np.random.shuffle(index)\n",
    "        img_paths = img_paths[index]\n",
    "        mask_paths = mask_paths[index]\n",
    "\n",
    "    return img_paths, mask_paths\n",
    "\n",
    "# dataset and dataloader\n",
    "input_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.2676, 0.2676, 0.2676], [0.180, 0.180, 0.180]),\n",
    "])\n",
    "data_kwargs = {'transform': input_transform}\n",
    "train_dataset = Segmentation(\"../zz\", mode='train',shuffle=True,**data_kwargs)\n",
    "val_dataset = Segmentation(\"../zz\", mode='val',shuffle=False,**data_kwargs)\n",
    "print(\"train_dataset :\",len(train_dataset))\n",
    "print(\"val_dataset:\",len(val_dataset))\n",
    "train_loader = data.DataLoader(dataset = train_dataset, batch_size= batch_size)\n",
    "val_loader = data.DataLoader(dataset = val_dataset, batch_size= batch_size)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Dataload的注意事项](https://zhuanlan.zhihu.com/p/66145913)\n",
    "\n",
    "在训练神经网络的时候，大部分时间都是在从磁盘中读取数据，而不是做 Backpropagation．\n",
    "\n",
    "这种症状的体现是使用 Nividia-smi 查看 GPU 使用率时，Memory-Usage 占用率很高，\n",
    "但是 GPU-Util 时常为 0%.解决方案下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "class data_prefetcher():\n",
    "    def __init__(self, loader):\n",
    "        self.loader = iter(loader)\n",
    "        self.stream = torch.cuda.Stream()\n",
    "        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)\n",
    "        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)\n",
    "        # With Amp, it isn't necessary to manually convert data to half.\n",
    "        # if args.fp16:\n",
    "        #     self.mean = self.mean.half()\n",
    "        #     self.std = self.std.half()\n",
    "        self.preload()\n",
    "\n",
    "    def preload(self):\n",
    "        try:\n",
    "            self.next_input, self.next_target = next(self.loader)\n",
    "        except StopIteration:\n",
    "            self.next_input = None\n",
    "            self.next_target = None\n",
    "            return\n",
    "        with torch.cuda.stream(self.stream):\n",
    "            self.next_input = self.next_input.cuda(non_blocking=True)\n",
    "            self.next_target = self.next_target.cuda(non_blocking=True)\n",
    "            # With Amp, it isn't necessary to manually convert data to half.\n",
    "            # if args.fp16:\n",
    "            #     self.next_input = self.next_input.half()\n",
    "            # else:\n",
    "            self.next_input = self.next_input.float()\n",
    "            self.next_input = self.next_input.sub_(self.mean).div_(self.std)\n",
    "\n",
    "            \n",
    "train_loader = DataLoader( train_dataset, batch_size=args.batch_size,  shuffle=(train_sampler is None),\n",
    "        num_workers=args.workers, pin_memory=True,  sampler=train_sampler, collate_fn=fast_collate)\n",
    "        \n",
    "prefetcher = data_prefetcher(train_loader)\n",
    "data, label = prefetcher.next()\n",
    "iteration = 0\n",
    "while data is not None:\n",
    "    iteration += 1\n",
    "    # 训练代码\n",
    "    data, label = prefetcher.next()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [不同数据加载方式性能比较](https://www.kaggle.com/hirune924/nvidia-dali-the-fastest-data-loading#NVIDIA-DALI)\n",
    "\n",
    "* PIL + torchvision\n",
    "\n",
    "* jpeg4\n",
    "\n",
    "```\n",
    "!apt-get install libturbojpeg0\n",
    "!pip install jpeg4py\n",
    "```\n",
    "* jpeg4py + albumentations\n",
    "\n",
    "* Data Loading by using DALI\n",
    "\n",
    "![results](./figs/results.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "## 搭建网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn \n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # 1 input image channel, 6 output channels, 5*5 square convolution\n",
    "        # kernel\n",
    "\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        # an affine operation: y = Wx + b\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # max pooling over a (2, 2) window\n",
    "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
    "        # If size is a square you can only specify a single number\n",
    "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
    "        x = x.view(-1, self.num_flat_features(x))\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "    def num_flat_features(self, x):\n",
    "        size = x.size()[1:] # all dimensions except the batch dimension\n",
    "        num_features = 1\n",
    "        for s in size:\n",
    "            num_features *= s\n",
    "        return num_features\n",
    "    \n",
    "    def weights_init(self):\n",
    "        for module in self.modules():\n",
    "            if isinstance(module, nn.Conv2d):\n",
    "                nn.init.normal_(module.weight, mean = 0, std = 1)\n",
    "                nn.init.constant_(module.bias, 0)\n",
    "    \n",
    "model = Net()\n",
    "print(model)\n",
    "model.weights_init()\n",
    "for module in model.modules():\n",
    "    if isinstance(module, nn.Conv2d):\n",
    "        weights = module.weight\n",
    "        weights = weights.reshape(-1).detach().cpu().numpy()\n",
    "        print(module.bias)                                       # Bias to zero\n",
    "        plt.hist(weights)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n",
       "  (relu1): ReLU()\n",
       "  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))\n",
       "  (relu2): ReLU()\n",
       "  (aavgp): AdaptiveAvgPool2d(output_size=1)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import OrderedDict\n",
    "import torch                     # basic tensor functions\n",
    "import torch.nn as nn            # everything neural network\n",
    "\n",
    "# Simple sequential model with named layers\n",
    "layers = OrderedDict([\n",
    "    (\"conv1\", nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)),\n",
    "    (\"relu1\", nn.ReLU()),\n",
    "    (\"conv2\", nn.Conv2d(20,64,5)),\n",
    "    (\"relu2\", nn.ReLU()),\n",
    "    (\"aavgp\", nn.AdaptiveAvgPool2d(1)),\n",
    "])\n",
    "model = nn.Sequential(layers)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### modules() vs children()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Printing children\n",
      "------------------------------\n",
      "[Sequential(\n",
      "  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "), Linear(in_features=10, out_features=2, bias=True)]\n",
      "\n",
      "\n",
      "Printing Modules\n",
      "------------------------------\n",
      "[myNet(\n",
      "  (convBN): Sequential(\n",
      "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  )\n",
      "  (linear): Linear(in_features=10, out_features=2, bias=True)\n",
      "), Sequential(\n",
      "  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=10, out_features=2, bias=True)]\n"
     ]
    }
   ],
   "source": [
    "class myNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.convBN =  nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))\n",
    "        self.linear =  nn.Linear(10,2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        pass\n",
    "\n",
    "Net = myNet()\n",
    "# Net = myNet().half() ## 指定模型为半精度运行\n",
    "\n",
    "print(\"Printing children\\n------------------------------\")\n",
    "print(list(Net.children()))\n",
    "print(\"\\n\\nPrinting Modules\\n------------------------------\")\n",
    "print(list(Net.modules()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 打印网络信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " myNet(\n",
      "  (convBN): Sequential(\n",
      "    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  )\n",
      "  (linear): Linear(in_features=10, out_features=2, bias=True)\n",
      ") \n",
      "-------------------------------\n",
      "convBN Sequential(\n",
      "  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))\n",
      "  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      ") \n",
      "-------------------------------\n",
      "convBN.0 Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)) \n",
      "-------------------------------\n",
      "convBN.1 BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) \n",
      "-------------------------------\n",
      "linear Linear(in_features=10, out_features=2, bias=True) \n",
      "-------------------------------\n"
     ]
    }
   ],
   "source": [
    "for x in Net.named_modules():\n",
    "  print(x[0], x[1], \"\\n-------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 不同的层设置不同的学习率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(10,5)\n",
    "        self.fc2 = nn.Linear(5,2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc2(self.fc1(x))\n",
    "\n",
    "Net = myNet()\n",
    "# optimiser = torch.optim.SGD(Net.parameters(), lr = 0.5)\n",
    "\n",
    "#######################################################################################################\n",
    "optimiser = torch.optim.SGD([{\"params\": Net.fc1.parameters(), 'lr' : 0.001, \"momentum\" : 0.99},\n",
    "                             {\"params\": Net.fc2.parameters()}], lr = 0.01, momentum = 0.9)\n",
    "\n",
    "\n",
    "#######################################################################################################\n",
    "params_bias = []\n",
    "params_wts = []\n",
    "# seperate the bias and weights parameters\n",
    "for name, parameter in Net.named_parameters():\n",
    "  if \"bias\" in name:\n",
    "    params_bias.append(parameter)\n",
    "  elif \"weight\" in name:\n",
    "    params_wts.append(parameter)\n",
    "\n",
    "# Set the optimiser to have different hyperparameters for bias and weights\n",
    "optimiser = torch.optim.SGD([{\"params\": params_bias, 'lr' : 0.001, \"momentum\" : 0.99},\n",
    "                             {\"params\": params_wts}], lr = 0.01, momentum = 0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "## 获得网络信息"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存和加载网络模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(Net, \"net.pth\")\n",
    "Net = torch.load(\"net.pth\")\n",
    "print(Net)\n",
    "\n",
    "# Save and load only the model parameters (recommended).\n",
    "torch.save(resnet.state_dict(), 'params.ckpt')\n",
    "resnet.load_state_dict(torch.load('params.ckpt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 获得网络的权重信息\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer: conv1.weight torch.Size([6, 1, 5, 5])\n",
      "layer: conv1.bias torch.Size([6])\n",
      "layer: conv2.weight torch.Size([16, 6, 5, 5])\n",
      "layer: conv2.bias torch.Size([16])\n",
      "layer: fc1.weight torch.Size([120, 400])\n",
      "layer: fc1.bias torch.Size([120])\n",
      "layer: fc2.weight torch.Size([84, 120])\n",
      "layer: fc2.bias torch.Size([84])\n",
      "layer: fc3.weight torch.Size([10, 84])\n",
      "layer: fc3.bias torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "model_dict = net.state_dict()\n",
    "for k,v in model_dict.items():\n",
    "    print(\"layer:\",k,v.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载网络预训练权重模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pretrainedweights(model, pretrained_model):\n",
    "    model_dict = model.state_dict()\n",
    "    pretrained_dict = torch.load(pretrained_model)\n",
    "    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
    "    model_dict.update(pretrained_dict)\n",
    "    model.load_state_dict(model_dict)\n",
    "    return model\n",
    "\n",
    "load_pretrainedweights(model,pretrained_model=\"model.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 检查模型是否正确加载权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_weight_loaded( ):\n",
    "\n",
    "    model =  EfficientNet_B4()\n",
    "    model_dict_v1 = model.state_dict()\n",
    "    model =  EfficientNet_B4()\n",
    "    model_dict_v2 = model.state_dict()\n",
    "    for k in model_dict_v1.keys():\n",
    "        try:\n",
    "            assert model_dict_v1[k].data.numpy().all() == model_dict_v2[k].data.numpy().all()\n",
    "        except AssertionError as e:\n",
    "            print(\"failed to load weights!\")\n",
    "            break\n",
    "    else:\n",
    "        print(\"--> model weight loading is  successful!\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型权重初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import  torch.nn from init\n",
    "def weights_init_normal(m):\n",
    "    classname = m.__class__.__name__\n",
    "    if classname.find('Conv')!=\"-1\":\n",
    "        init.normal(m.weight.data, 0.0, 0.02)\n",
    "    elif classname.find('Linear')!=\"-1\":\n",
    "        init.normal(m.weight.data, 0.0, 0.02)\n",
    "    elif classname.find('BatchNorm2d')!=\"-1\":\n",
    "        init.normal(m.weight.data, 1.0, 0.02)\n",
    "        init.constant(m.bias.data, 0.0)\n",
    "        \n",
    "model.apply(weights_init_normal)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载网络模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load(weight_path, map_location=lambda storage, loc: storage))\n",
    "model.load_state_dict(torch.load(weight_path, map_location='cpu'))\n",
    "model = model.module#才是你的模型            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 获得网络的中间某一层的输出和参数量\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n",
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1            [-1, 6, 28, 28]             156\n",
      "            Conv2d-2           [-1, 16, 10, 10]           2,416\n",
      "            Linear-3                  [-1, 120]          48,120\n",
      "            Linear-4                   [-1, 84]          10,164\n",
      "            Linear-5                   [-1, 10]             850\n",
      "================================================================\n",
      "Total params: 61,706\n",
      "Trainable params: 61,706\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.05\n",
      "Params size (MB): 0.24\n",
      "Estimated Total Size (MB): 0.29\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "print(net)\n",
    "from torchsummary import summary\n",
    "summary(net,(1,32,32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据并行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.device_count() > 1:\n",
    "    model = nn.DataParallel(model)\n",
    "print(\"model parameters device:\", next(model.parameters()).is_cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 迁移学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and load the pretrained ResNet-18.\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "resnet = resnet18(pretrained=True)\n",
    "# If you want to finetune only the top layer of the model, set as below.\n",
    "for param in resnet.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# Replace the top layer for finetuning.\n",
    "resnet.fc = nn.Linear(resnet.fc.in_features, 100)  # 100 is an example.\n",
    "\n",
    "# Forward pass.\n",
    "images = torch.randn(64, 3, 224, 224)\n",
    "outputs = resnet(images)\n",
    "print (outputs.size())     # (64, 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": true
   },
   "source": [
    "## hook\n",
    "\n",
    "由于pytorch会自动舍弃图计算的中间结果，所以想要获取这些数值就需要使用钩子函数。\n",
    "钩子函数包括Variable的钩子和nn.Module钩子，用法相似。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### register_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "\n",
    "grad_list = []\n",
    "def print_grad(grad):\n",
    "    \"\"\"Tensor hook: store the gradient that reaches the hooked tensor in grad_list.\"\"\"\n",
    "    grad_list.append(grad)\n",
    "\n",
    "# Leaf tensor; requires_grad=True makes autograd fill in x.grad.\n",
    "x = torch.ones(2, 1, requires_grad=True)\n",
    "y = x + 2\n",
    "z = torch.mean(torch.pow(y, 2))\n",
    "lr = 1e-3\n",
    "# y is an intermediate result, so its gradient is only observable through hooks.\n",
    "# One hook prints it, the other records it so that grad_list is not empty below.\n",
    "hooks = [y.register_hook(print), y.register_hook(print_grad)]\n",
    "z.backward()\n",
    "# Manual gradient-descent step; no_grad keeps the in-place update out of the graph.\n",
    "with torch.no_grad():\n",
    "    x -= lr * x.grad\n",
    "print(grad_list)\n",
    "# Remove the hooks so later backward passes neither print nor accumulate.\n",
    "for hook in hooks:\n",
    "    hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=160, out_features=5, bias=True)\n",
      "------------Input Grad------------\n",
      "torch.Size([5])\n",
      "torch.Size([5])\n",
      "------------Output Grad------------\n",
      "torch.Size([5])\n",
      "\n",
      "\n",
      "Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))\n",
      "------------Input Grad------------\n",
      "None found for Gradient\n",
      "torch.Size([10, 3, 2, 2])\n",
      "torch.Size([10])\n",
      "------------Output Grad------------\n",
      "torch.Size([1, 10, 4, 4])\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "\n",
    "def hook_fn(m, i, o):\n",
    "    \"\"\"Backward-hook callback that prints the module and gradient shapes.\n",
    "\n",
    "    Args:\n",
    "        m: the hooked module; printed with its default repr.\n",
    "        i: iterable of gradients listed under the \"Input Grad\" banner.\n",
    "        o: iterable of gradients listed under the \"Output Grad\" banner.\n",
    "\n",
    "    Entries that have no ``shape`` attribute (such as ``None``) are shown as\n",
    "    \"None found for Gradient\".\n",
    "    \"\"\"\n",
    "    def report(banner, grads):\n",
    "        # Shared printer for both gradient collections.\n",
    "        print(banner)\n",
    "        for g in grads:\n",
    "            try:\n",
    "                shape = g.shape\n",
    "            except AttributeError:\n",
    "                print(\"None found for Gradient\")\n",
    "            else:\n",
    "                print(shape)\n",
    "\n",
    "    print(m)\n",
    "    report(\"------------Input Grad------------\", i)\n",
    "    report(\"------------Output Grad------------\", o)\n",
    "    print(\"\\n\")\n",
    "\n",
    "class myNet(nn.Module):\n",
    "    \"\"\"Tiny conv -> ReLU -> flatten -> linear network used by the hook demos.\n",
    "\n",
    "    ``flatten`` is a plain lambda rather than an ``nn.Module``, so it is not\n",
    "    registered as a submodule. It uses ``view(-1)``, which collapses every\n",
    "    dimension including the batch one, so ``fc1`` only receives its 160\n",
    "    features (10 channels * 4 * 4) for a single 3x8x8 sample.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(myNet, self).__init__()\n",
    "        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=2, stride=2)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.flatten = lambda x: x.view(-1)\n",
    "        self.fc1 = nn.Linear(in_features=160, out_features=5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run conv + ReLU, flatten the result and apply the linear layer.\"\"\"\n",
    "        activated = self.relu(self.conv(x))\n",
    "        flat = self.flatten(activated)\n",
    "        return self.fc1(flat)\n",
    "\n",
    "net = myNet()\n",
    "\n",
    "# Attach the same printing callback to both layers as a backward hook.\n",
    "net.conv.register_backward_hook(hook_fn)\n",
    "net.fc1.register_backward_hook(hook_fn)\n",
    "# One random sample of shape (1, 3, 8, 8).\n",
    "inp = torch.randn(1,3,8,8)\n",
    "out = net(inp)\n",
    "# Reduce the output to a scalar and backpropagate, which triggers the hooks.\n",
    "(1 - out.mean()).backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each hooked submodule to the output of its most recent forward call.\n",
    "visualisation = {}\n",
    "\n",
    "inp = torch.randn(1,3,8,8)\n",
    "\n",
    "def hook_fn(m, i, o):\n",
    "    \"\"\"Forward hook: remember the output `o` of module `m` (the inputs `i` are ignored).\n",
    "\n",
    "    NOTE: this rebinds the name `hook_fn` used by the backward-hook demo cell.\n",
    "    \"\"\"\n",
    "    visualisation[m] = o\n",
    "\n",
    "net = myNet()\n",
    "\n",
    "# children() is the public way to iterate over direct submodules; the lambda stored\n",
    "# in net.flatten is not a module and is therefore not hooked.\n",
    "for layer in net.children():\n",
    "    layer.register_forward_hook(hook_fn)\n",
    "\n",
    "out = net(inp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": false
   },
   "source": [
    "## [常用的损失函数](https://github.com/CoinCheung/pytorch-loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dice_loss(logits,true, eps=1e-7):\n",
    "    \"\"\"Computes the Sørensen–Dice loss.\n",
    "\n",
    "    PyTorch optimizers minimize their objective, whereas a larger Dice\n",
    "    coefficient means better overlap, so the value returned is\n",
    "    ``1 - dice`` (close to 0 for a perfect prediction).\n",
    "\n",
    "    Args:\n",
    "        logits: a tensor of shape [B, C, H, W]. Corresponds to\n",
    "            the raw output or logits of the model. With C == 1 a sigmoid\n",
    "            is used, otherwise a softmax over the class dimension.\n",
    "        true: an integer tensor of shape [B, 1, H, W] holding class indices\n",
    "            (it is used to index an identity matrix).\n",
    "        eps: added to the denominator for numerical stability.\n",
    "    Returns:\n",
    "        dice_loss: scalar tensor, one minus the Dice coefficient averaged\n",
    "            over the classes.\n",
    "    \"\"\"\n",
    "    num_classes = logits.shape[1]\n",
    "    labels = true.squeeze(1)\n",
    "    if num_classes == 1:\n",
    "        # The identity matrix is built on the labels' device: a CPU matrix\n",
    "        # cannot be indexed with CUDA labels.\n",
    "        true_1_hot = torch.eye(num_classes + 1, device=true.device)[labels]\n",
    "        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()\n",
    "        # Put the foreground channel first to line up with [pos_prob, neg_prob].\n",
    "        true_1_hot = torch.cat([true_1_hot[:, 1:2, :, :], true_1_hot[:, 0:1, :, :]], dim=1)\n",
    "        pos_prob = torch.sigmoid(logits)\n",
    "        probas = torch.cat([pos_prob, 1 - pos_prob], dim=1)\n",
    "    else:\n",
    "        true_1_hot = torch.eye(num_classes, device=true.device)[labels]\n",
    "        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()\n",
    "        # torch.softmax avoids relying on an `F` alias that this cell does not import.\n",
    "        probas = torch.softmax(logits, dim=1)\n",
    "    true_1_hot = true_1_hot.type(logits.type())\n",
    "    # Reduce over the batch and spatial dimensions, keeping one value per class.\n",
    "    dims = (0,) + tuple(range(2, true.ndimension()))\n",
    "    intersection = torch.sum(probas * true_1_hot, dims)\n",
    "    cardinality = torch.sum(probas + true_1_hot, dims)\n",
    "    dice = (2. * intersection / (cardinality + eps)).mean()\n",
    "    return (1 - dice)\n",
    "\n",
    "\n",
    "# Three alternative criteria; each assignment overwrites the previous one,\n",
    "# so after this cell runs `criterion` is `dice_loss`.\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "criterion = nn.BCELoss().to(device)\n",
    "criterion = dice_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "class CriterionWithLabelSmoothing(nn.Module):\n",
    "    \"\"\"Blend an unreduced criterion with a uniform-target term.\n",
    "\n",
    "    Args:\n",
    "        criterion: loss object whose ``reduction`` attribute is 'none'.\n",
    "        alpha: weight of the uniform term; the criterion gets ``1 - alpha``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``criterion.reduction`` is not 'none'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, criterion, alpha=0.2):\n",
    "        super().__init__()\n",
    "        if criterion.reduction != 'none':\n",
    "            raise ValueError(\"Input criterion should have reduction equal none\")\n",
    "        self.criterion = criterion\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self, logits, targets):\n",
    "        \"\"\"Return the batch mean of the blended per-sample loss.\"\"\"\n",
    "        per_sample = self.criterion(logits, targets)\n",
    "        # Negative mean log-probability over the class dimension (dim 1).\n",
    "        uniform_term = -torch.log_softmax(logits, dim=1).mean(dim=1)\n",
    "        blended = (1.0 - self.alpha) * per_sample + self.alpha * uniform_term\n",
    "        return blended.mean(dim=0)\n",
    "\n",
    "def get_criterion(alpha):\n",
    "    \"\"\"Build a label-smoothed cross-entropy criterion.\n",
    "\n",
    "    Args:\n",
    "        alpha: smoothing weight forwarded to CriterionWithLabelSmoothing\n",
    "            (previously ignored in favour of a hard-coded 0.2).\n",
    "    Returns:\n",
    "        A CriterionWithLabelSmoothing wrapping an unreduced CrossEntropyLoss.\n",
    "    \"\"\"\n",
    "    return CriterionWithLabelSmoothing(nn.CrossEntropyLoss(reduction='none'), alpha=alpha)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OHEM_crossEntropy_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class OHEM_crossEntropy_loss(nn.Module):\n",
    "    \"\"\"Online hard example mining on top of cross-entropy.\n",
    "\n",
    "    Only the largest ``top_k`` fraction of the per-element losses is averaged.\n",
    "\n",
    "    Args:\n",
    "        top_k: fraction of elements to keep, e.g. 0.7 keeps the hardest 70%.\n",
    "    \"\"\"\n",
    "    def __init__(self, top_k=0.7):\n",
    "        super(OHEM_crossEntropy_loss, self).__init__()\n",
    "        self.loss = nn.CrossEntropyLoss(reduction= 'none')\n",
    "        self.top_k = top_k\n",
    "\n",
    "    def forward(self, logits, targets):\n",
    "        \"\"\"Return the mean of the hardest elements of the unreduced loss.\"\"\"\n",
    "        # Flatten so that dense targets (e.g. N x H x W) are ranked over all elements.\n",
    "        loss = self.loss(logits, targets).reshape(-1)\n",
    "        # int() truncation would give k == 0 for small inputs, and the mean of an\n",
    "        # empty selection is NaN; keep at least one element and never more than exist.\n",
    "        k = min(max(1, int(self.top_k * loss.numel())), loss.numel())\n",
    "        valid_loss, _ = torch.topk(loss, k)\n",
    "        return torch.mean(valid_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### weights_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Classes 0 and 1 are harder, so they get weight 1.2; class 2 is easier, so it gets weight 0.8\n",
    "weights = [1.2, 1.2, 0.8]\n",
    "class_weights = torch.FloatTensor(weights).to(device)\n",
    "criterion = torch.nn.CrossEntropyLoss(weight=class_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### label_smoothing_v1_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "class LabelSmoothSoftmaxCEV1_Loss(nn.Module):\n",
    "    '''\n",
    "    Cross-entropy against label-smoothed targets, with an ignore label.\n",
    "\n",
    "    The target for the true class is 1 - lb_smooth and for every other\n",
    "    class lb_smooth / num_classes. Gradients come from autograd.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, lb_smooth=0.1, reduction='mean', lb_ignore=-100):\n",
    "        super(LabelSmoothSoftmaxCEV1_Loss, self).__init__()\n",
    "        self.lb_smooth = lb_smooth\n",
    "        self.reduction = reduction\n",
    "        self.lb_ignore = lb_ignore\n",
    "        self.log_softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''\n",
    "        args: logits: tensor of shape (N, C, H, W)\n",
    "        args: label: tensor of shape(N, H, W)\n",
    "        returns: for 'mean' the loss summed and divided by the number of\n",
    "            non-ignored positions, for 'sum' the summed loss, otherwise the\n",
    "            per-position loss (class dimension reduced away)\n",
    "        '''\n",
    "        # Build the smoothed targets without tracking gradients.\n",
    "        with torch.no_grad():\n",
    "            num_classes = logits.size(1)\n",
    "            label = label.clone().detach()\n",
    "            ignore = label == self.lb_ignore\n",
    "            n_valid = (ignore == 0).sum()\n",
    "            # Ignored positions are mapped to class 0 so that scatter_ gets valid\n",
    "            # indices; their loss is zeroed further down.\n",
    "            label[ignore] = 0\n",
    "            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes\n",
    "            # Fill everything with lb_neg, then write lb_pos at the true class.\n",
    "            label = torch.empty_like(logits).fill_(\n",
    "                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()\n",
    "\n",
    "        logs = self.log_softmax(logits)\n",
    "        loss = -torch.sum(logs * label, dim=1)\n",
    "        # Ignored positions contribute nothing to the loss.\n",
    "        loss[ignore] = 0\n",
    "        if self.reduction == 'mean':\n",
    "            loss = loss.sum() / n_valid\n",
    "        if self.reduction == 'sum':\n",
    "            loss = loss.sum()\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### label_smoothing_v2_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelSmoothSoftmaxCEFunction(torch.autograd.Function):\n",
    "    '''\n",
    "    Label-smoothed softmax cross-entropy with a hand-written backward pass.\n",
    "\n",
    "    Use it through ``LabelSmoothSoftmaxCEFunction.apply(logits, label,\n",
    "    lb_smooth, reduction, lb_ignore)``.\n",
    "    '''\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, logits, label, lb_smooth, reduction, lb_ignore):\n",
    "        '''\n",
    "        args: logits: tensor with the class dimension at index 1\n",
    "        args: label: integer tensor of class indices without the class dimension\n",
    "        args: lb_smooth: smoothing factor; the true class gets 1 - lb_smooth,\n",
    "            every other class lb_smooth / num_classes\n",
    "        args: reduction: 'mean' (sum divided by the number of non-ignored\n",
    "            positions), 'sum', or anything else for the per-position loss\n",
    "        args: lb_ignore: label value whose positions are excluded\n",
    "        '''\n",
    "        # prepare label\n",
    "        num_classes = logits.size(1)\n",
    "        label = label.clone().detach()\n",
    "        ignore = label == lb_ignore\n",
    "        n_valid = (ignore == 0).sum()\n",
    "        # scatter_ needs valid indices; the loss at these positions is zeroed below\n",
    "        label[ignore] = 0\n",
    "        lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes\n",
    "        label = torch.empty_like(logits).fill_(\n",
    "            lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()\n",
    "\n",
    "        scores = torch.softmax(logits, dim=1)\n",
    "        # log_softmax is numerically safer than log(softmax(x))\n",
    "        logs = torch.log_softmax(logits, dim=1)\n",
    "\n",
    "        ctx.scores = scores\n",
    "        ctx.label = label\n",
    "        ctx.ignore = ignore\n",
    "        ctx.reduction = reduction\n",
    "        ctx.n_valid = n_valid\n",
    "\n",
    "        loss = -torch.sum(logs * label, dim=1)\n",
    "        loss[ignore] = 0\n",
    "        if reduction == 'mean':\n",
    "            loss = loss.sum() / n_valid\n",
    "        if reduction == 'sum':\n",
    "            loss = loss.sum()\n",
    "        return loss\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        '''Return the gradient w.r.t. logits and None for the four non-tensor settings.'''\n",
    "        scores = ctx.scores\n",
    "        label = ctx.label\n",
    "        reduction = ctx.reduction\n",
    "        # d/dz of -sum_c t_c * log_softmax(z)_c is softmax(z) * sum_c t_c - t.\n",
    "        # The smoothed targets do not sum to 1, so the factor cannot be dropped.\n",
    "        grad = scores * label.sum(dim=1, keepdim=True) - label\n",
    "        # Ignored positions were zeroed in forward and must not receive gradient.\n",
    "        grad = grad * (ctx.ignore == 0).unsqueeze(1).type_as(grad)\n",
    "        if reduction == 'mean':\n",
    "            grad = grad * grad_output / ctx.n_valid\n",
    "        elif reduction == 'sum':\n",
    "            grad = grad * grad_output\n",
    "        else:\n",
    "            # Unreduced loss: grad_output lacks the class dimension.\n",
    "            grad = grad * grad_output.unsqueeze(1)\n",
    "        # One entry per forward input; only logits is differentiable.\n",
    "        return grad, None, None, None, None\n",
    "\n",
    "\n",
    "class LabelSmoothSoftmaxCEV2_losss(nn.Module):\n",
    "    '''\n",
    "    Module wrapper that stores the smoothing settings and forwards every\n",
    "    call to LabelSmoothSoftmaxCEFunction.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, lb_smooth=0.1, reduction='mean', lb_ignore=-100):\n",
    "        super().__init__()\n",
    "        self.lb_smooth, self.reduction, self.lb_ignore = lb_smooth, reduction, lb_ignore\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''Apply the custom autograd function with the stored settings.'''\n",
    "        settings = (self.lb_smooth, self.reduction, self.lb_ignore)\n",
    "        return LabelSmoothSoftmaxCEFunction.apply(logits, label, *settings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Focal loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FocalLoss(nn.Module):\n",
    "    '''\n",
    "    Sigmoid focal loss on one-hot targets, with an ignore label.\n",
    "\n",
    "    args: alpha: weight of the positive (true-class) entries; all other\n",
    "        entries get 1 - alpha\n",
    "    args: gamma: exponent of the (1 - pt) modulating factor\n",
    "    args: reduction: 'mean' (sum divided by the number of non-ignored\n",
    "        positions), 'sum', or anything else for the per-position loss\n",
    "    args: ignore_lb: label value whose positions are excluded\n",
    "    '''\n",
    "    def __init__(self, alpha=0.25, gamma=2, reduction='mean', ignore_lb=255):\n",
    "        super().__init__()\n",
    "        self.alpha, self.gamma = alpha, gamma\n",
    "        self.reduction, self.ignore_lb = reduction, ignore_lb\n",
    "        self.crit = nn.BCEWithLogitsLoss(reduction='none')\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''\n",
    "        args: logits: tensor of shape (N, C, H, W)\n",
    "        args: label: tensor of shape(N, H, W)\n",
    "        '''\n",
    "        # targets and weights are constants: build them without gradient tracking\n",
    "        with torch.no_grad():\n",
    "            target = label.clone().detach()\n",
    "            ignore = target == self.ignore_lb\n",
    "            n_valid = (ignore == 0).sum()\n",
    "            # give ignored positions a valid class index; their loss is zeroed later\n",
    "            target[ignore] = 0\n",
    "            lb_one_hot = torch.zeros_like(logits).scatter_(\n",
    "                1, target.unsqueeze(1), 1).detach()\n",
    "            is_pos = lb_one_hot == 1\n",
    "            alpha = torch.empty_like(logits).fill_(1 - self.alpha)\n",
    "            alpha[is_pos] = self.alpha\n",
    "\n",
    "        # probability assigned to the correct binary outcome of every entry\n",
    "        probs = torch.sigmoid(logits)\n",
    "        pt = torch.where(is_pos, probs, 1 - probs)\n",
    "        modulator = torch.pow(1 - pt, self.gamma)\n",
    "        loss = (alpha * modulator * self.crit(logits, lb_one_hot)).sum(dim=1)\n",
    "        loss[ignore] = 0\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.sum() / n_valid\n",
    "        if self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dual_Focal loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dual_Focal_loss(nn.Module):\n",
    "    '''\n",
    "    Loss from the paper at https://arxiv.org/abs/1909.11932.\n",
    "    The original author notes that it did not work in their projects and\n",
    "    asks to be corrected if the implementation contains mistakes.\n",
    "\n",
    "    args: ignore_lb: label value whose positions are excluded\n",
    "    args: eps: constant added inside the logarithm\n",
    "    args: reduction: 'mean' (sum divided by the number of non-ignored\n",
    "        positions), 'sum', or anything else for the per-position loss\n",
    "    '''\n",
    "\n",
    "    def __init__(self, ignore_lb=255, eps=1e-5, reduction='mean'):\n",
    "        super().__init__()\n",
    "        self.ignore_lb, self.eps, self.reduction = ignore_lb, eps, reduction\n",
    "        self.mse = nn.MSELoss(reduction='none')\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''Compute the loss for logits (class dimension at index 1) and integer labels.'''\n",
    "        ignore = label.data.cpu() == self.ignore_lb\n",
    "        n_valid = (ignore == 0).sum()\n",
    "        target = label.clone()\n",
    "        # give ignored positions a valid class index; their loss is zeroed later\n",
    "        target[ignore] = 0\n",
    "        lb_one_hot = torch.zeros_like(logits.data).scatter_(1, target.unsqueeze(1), 1).detach()\n",
    "\n",
    "        pred = torch.softmax(logits, dim=1)\n",
    "        sq_err = self.mse(pred, lb_one_hot)\n",
    "        loss = -torch.log(self.eps + 1. - sq_err).sum(dim=1)\n",
    "        loss[ignore] = 0\n",
    "        if self.reduction == 'mean':\n",
    "            return loss.sum() / n_valid\n",
    "        if self.reduction == 'sum':\n",
    "            return loss.sum()\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dice loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GeneralizedSoftDiceLoss(nn.Module):\n",
    "    '''\n",
    "    Soft Dice loss on sigmoid probabilities, computed per sample.\n",
    "\n",
    "    args: p: exponent applied to probabilities and targets in the denominator\n",
    "    args: smooth: constant added to numerator and denominator\n",
    "    args: reduction: 'mean' averages over the batch, anything else returns\n",
    "        one value per sample\n",
    "    args: weight: optional per-class weights (any sequence accepted by torch.tensor)\n",
    "    args: ignore_lb: label value whose positions get an all-zero target\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                 p=1,\n",
    "                 smooth=1,\n",
    "                 reduction='mean',\n",
    "                 weight=None,\n",
    "                 ignore_lb=255):\n",
    "        super(GeneralizedSoftDiceLoss, self).__init__()\n",
    "        self.p = p\n",
    "        self.smooth = smooth\n",
    "        self.reduction = reduction\n",
    "        self.weight = None if weight is None else torch.tensor(weight)\n",
    "        self.ignore_lb = ignore_lb\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''\n",
    "        args: logits: tensor of shape (N, C, H, W)\n",
    "        args: label: tensor of shape(N, H, W)\n",
    "        '''\n",
    "        # One-hot targets; the mask stays on the label's device (no CPU round trip).\n",
    "        ignore = label == self.ignore_lb\n",
    "        label = label.clone()\n",
    "        # give ignored positions a valid class index for scatter_\n",
    "        label[ignore] = 0\n",
    "        lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)\n",
    "        # clear every class channel at the ignored positions\n",
    "        lb_one_hot = lb_one_hot.masked_fill_(ignore.unsqueeze(1), 0).detach()\n",
    "\n",
    "        # compute loss\n",
    "        probs = torch.sigmoid(logits)\n",
    "        numer = torch.sum((probs*lb_one_hot), dim=(2, 3))\n",
    "        denom = torch.sum(probs.pow(self.p)+lb_one_hot.pow(self.p), dim=(2, 3))\n",
    "        if self.weight is not None:\n",
    "            # self.weight is a plain attribute, so Module.to()/.cuda() never moves it;\n",
    "            # align it with the data here to avoid a device or dtype mismatch.\n",
    "            weight = self.weight.to(device=numer.device, dtype=numer.dtype).view(1, -1)\n",
    "            numer = numer * weight\n",
    "            denom = denom * weight\n",
    "        numer = torch.sum(numer, dim=1)\n",
    "        denom = torch.sum(denom, dim=1)\n",
    "        loss = 1 - (2*numer+self.smooth)/(denom+self.smooth)\n",
    "\n",
    "        if self.reduction == 'mean':\n",
    "            loss = loss.mean()\n",
    "        return loss\n",
    "\n",
    "\n",
    "class BatchSoftDiceLoss(nn.Module):\n",
    "    '''\n",
    "    Soft Dice loss on sigmoid probabilities, computed once over the whole batch.\n",
    "\n",
    "    args: p: exponent applied to probabilities and targets in the denominator\n",
    "    args: smooth: constant added to numerator and denominator\n",
    "    args: weight: optional per-class weights (any sequence accepted by torch.tensor)\n",
    "    args: ignore_lb: label value whose positions get an all-zero target\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                 p=1,\n",
    "                 smooth=1,\n",
    "                 weight=None,\n",
    "                 ignore_lb=255):\n",
    "        super(BatchSoftDiceLoss, self).__init__()\n",
    "        self.p = p\n",
    "        self.smooth = smooth\n",
    "        self.weight = None if weight is None else torch.tensor(weight)\n",
    "        self.ignore_lb = ignore_lb\n",
    "\n",
    "    def forward(self, logits, label):\n",
    "        '''\n",
    "        args: logits: tensor of shape (N, C, H, W)\n",
    "        args: label: tensor of shape(N, H, W)\n",
    "        '''\n",
    "        # One-hot targets; the mask stays on the label's device (no CPU round trip).\n",
    "        ignore = label == self.ignore_lb\n",
    "        label = label.clone()\n",
    "        # give ignored positions a valid class index for scatter_\n",
    "        label[ignore] = 0\n",
    "        lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)\n",
    "        # clear every class channel at the ignored positions\n",
    "        lb_one_hot = lb_one_hot.masked_fill_(ignore.unsqueeze(1), 0).detach()\n",
    "\n",
    "        # compute loss\n",
    "        probs = torch.sigmoid(logits)\n",
    "        numer = torch.sum((probs*lb_one_hot), dim=(2, 3))\n",
    "        denom = torch.sum(probs.pow(self.p)+lb_one_hot.pow(self.p), dim=(2, 3))\n",
    "        if self.weight is not None:\n",
    "            # self.weight is a plain attribute, so Module.to()/.cuda() never moves it;\n",
    "            # align it with the data here to avoid a device or dtype mismatch.\n",
    "            weight = self.weight.to(device=numer.device, dtype=numer.dtype).view(1, -1)\n",
    "            numer = numer * weight\n",
    "            denom = denom * weight\n",
    "        # a single ratio over all samples and classes\n",
    "        numer = torch.sum(numer)\n",
    "        denom = torch.sum(denom)\n",
    "        loss = 1 - (2*numer+self.smooth)/(denom+self.smooth)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    " \n",
    "class DiceLoss(nn.Module):\n",
    "    \"\"\"Soft Dice loss averaged over the batch, using a smoothing constant of 1.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        \"\"\"Return 1 minus the mean per-sample Dice score.\n",
    "\n",
    "        Both tensors are flattened to (N, -1), N being ``target.size(0)``.\n",
    "        \"\"\"\n",
    "        batch = target.size(0)\n",
    "        smooth = 1\n",
    "\n",
    "        pred = input.view(batch, -1)\n",
    "        truth = target.view(batch, -1)\n",
    "        overlap = (pred * truth).sum(1)\n",
    "\n",
    "        dice = (2 * overlap + smooth) / (pred.sum(1) + truth.sum(1) + smooth)\n",
    "        return 1 - dice.sum() / batch\n",
    " \n",
    "class MulticlassDiceLoss(nn.Module):\n",
    "    \"\"\"\n",
    "    Sum of per-class DiceLoss values.\n",
    "\n",
    "    Requires a one-hot encoded target. input.shape[0:1] and target.shape[0:1]\n",
    "    must be (N, C), where N is the batch size and C the number of classes.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input, target, weights=None):\n",
    "        \"\"\"Apply DiceLoss to each class channel and add the results.\n",
    "\n",
    "        When ``weights`` is given, the loss of class ``c`` is multiplied by\n",
    "        ``weights[c]``; when it is None every class counts equally.\n",
    "        \"\"\"\n",
    "        per_class = DiceLoss()\n",
    "        totalLoss = 0\n",
    "        num_classes = target.shape[1]\n",
    "        for cls in range(num_classes):\n",
    "            diceLoss = per_class(input[:, cls], target[:, cls])\n",
    "            if weights is not None:\n",
    "                diceLoss *= weights[cls]\n",
    "            totalLoss += diceLoss\n",
    "        return totalLoss\n",
    "\n",
    "def IOU(pred , mask):\n",
    "    \"\"\"Mean soft intersection-over-union across the batch.\n",
    "\n",
    "    Both tensors are flattened to (N, -1), N being ``mask.size(0)``; a small\n",
    "    epsilon is added to numerator and denominator.\n",
    "    \"\"\"\n",
    "    eps=1e-6\n",
    "    batch = mask.size(0)\n",
    "    flat_pred = pred.view(batch, -1)\n",
    "    flat_mask = mask.view(batch, -1)\n",
    "    overlap = (flat_pred * flat_mask).sum(1)\n",
    "    union = flat_pred.sum(1) + flat_mask.sum(1) - overlap\n",
    "    return ((overlap + eps) / (union + eps)).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AMSoftmax loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AMSoftmax(nn.Module):\n",
    "    '''\n",
    "    Additive-margin softmax loss with its own learnable class weights.\n",
    "\n",
    "    args: in_feats: feature dimension of the inputs\n",
    "    args: n_classes: number of classes (columns of W)\n",
    "    args: m: margin subtracted from the cosine of the true class\n",
    "    args: s: scale applied to the cosines before cross-entropy\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_classes=10,\n",
    "                 m=0.3,\n",
    "                 s=15):\n",
    "        super(AMSoftmax, self).__init__()\n",
    "        self.m = m\n",
    "        self.s = s\n",
    "        self.in_feats = in_feats\n",
    "        self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), requires_grad=True)\n",
    "        self.ce = nn.CrossEntropyLoss()\n",
    "        # overwrite the random values with Xavier-normal initialisation\n",
    "        nn.init.xavier_normal_(self.W, gain=1)\n",
    "\n",
    "    def forward(self, x, lb):\n",
    "        '''\n",
    "        args: x: features of shape (N, in_feats)\n",
    "        args: lb: integer class labels of shape (N,)\n",
    "        '''\n",
    "        assert x.size()[0] == lb.size()[0]\n",
    "        assert x.size()[1] == self.in_feats\n",
    "        # L2-normalise rows of x and columns of W so that their product is a cosine\n",
    "        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)\n",
    "        x_norm = torch.div(x, x_norm)\n",
    "        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)\n",
    "        w_norm = torch.div(self.W, w_norm)\n",
    "        costh = torch.mm(x_norm, w_norm)\n",
    "        # Build the margin directly on costh's device and dtype: no CPU round trip,\n",
    "        # and no bare .cuda() that would pick the default GPU instead of x's.\n",
    "        lb_view = lb.view(-1, 1).to(costh.device)\n",
    "        delt_costh = torch.zeros_like(costh).scatter_(1, lb_view, self.m)\n",
    "        costh_m = costh - delt_costh\n",
    "        costh_m_s = self.s * costh_m\n",
    "        loss = self.ce(costh_m_s, lb)\n",
    "        return loss\n",
    "\n",
    "\n",
    "# Smoke test: 20 random 1024-d features with random labels in [0, 10).\n",
    "criteria = AMSoftmax(1024, 10)\n",
    "a = torch.randn(20, 1024)\n",
    "lb = torch.randint(0, 10, (20, ), dtype=torch.long)\n",
    "loss = criteria(a, lb)\n",
    "loss.backward()\n",
    "\n",
    "# Loss value, shape of the first parameter, and the parameter's type.\n",
    "print(loss.detach().numpy())\n",
    "print(list(criteria.parameters())[0].shape)\n",
    "print(type(next(criteria.parameters())))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### mixup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py\n",
    "\n",
    "def mixup_data(x, y, alpha=1.0, use_cuda=True):\n",
    "    '''Returns mixed inputs, pairs of targets, and lambda\n",
    "\n",
    "    args: x: batch of inputs, batch dimension first\n",
    "    args: y: targets aligned with x along the first dimension\n",
    "    args: alpha: parameter of the Beta(alpha, alpha) distribution lambda is\n",
    "        drawn from; alpha <= 0 disables mixing (lambda = 1)\n",
    "    args: use_cuda: kept for backward compatibility; the permutation index\n",
    "        now always follows x's device, so CPU inputs work with the default\n",
    "    returns: (mixed_x, y_a, y_b, lam) where mixed_x = lam * x + (1 - lam) * x[index],\n",
    "        y_a = y and y_b = y[index]\n",
    "    '''\n",
    "    if alpha > 0:\n",
    "        lam = np.random.beta(alpha, alpha)\n",
    "    else:\n",
    "        lam = 1\n",
    "\n",
    "    batch_size = x.size()[0]\n",
    "    # Draw the permutation on the CPU generator (as before) and move it next to x,\n",
    "    # instead of calling .cuda(), which fails without a GPU and ignores x's device.\n",
    "    index = torch.randperm(batch_size).to(x.device)\n",
    "\n",
    "    mixed_x = lam * x + (1 - lam) * x[index, :]\n",
    "    y_a, y_b = y, y[index]\n",
    "    return mixed_x, y_a, y_b, lam\n",
    "\n",
    "\n",
    "def mixup_criterion(criterion, pred, y_a, y_b, lam):\n",
    "    \"\"\"Combine the criterion for both target sets with weights lam and 1 - lam.\"\"\"\n",
    "    loss_a = criterion(pred, y_a)\n",
    "    loss_b = criterion(pred, y_b)\n",
    "    return lam * loss_a + (1 - lam) * loss_b\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Excerpt of a training step: `inputs`, `targets`, `args`, `use_cuda`, `net`,\n",
    "# `criterion`, `train_loss`, `total` and `correct` must come from the surrounding loop.\n",
    "inputs, targets = inputs.cuda(), targets.cuda()\n",
    "inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,args.alpha, use_cuda)\n",
    "outputs = net(inputs)\n",
    "loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)\n",
    "# The loss is a 0-dim tensor; indexing it with [0] raises on PyTorch >= 0.5, so use item().\n",
    "train_loss += loss.item()\n",
    "_, predicted = torch.max(outputs.data, 1)\n",
    "total += targets.size(0)\n",
    "# Accuracy is weighted the same way as the loss.\n",
    "correct += (lam * predicted.eq(targets_a.data).cpu().sum().float()\n",
    "            + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NLLLoss和CrossEntropyLoss(Softmax–>Log–>NLLLoss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: tensor([[-0.3314,  0.3583,  0.2077],\n",
      "        [ 1.4749, -0.0436,  0.2306],\n",
      "        [ 1.2537,  0.2282, -0.0145],\n",
      "        [-1.5693,  0.3295,  0.2172]])\n",
      "softmax x: tensor([[0.2124, 0.4234, 0.3642],\n",
      "        [0.6635, 0.1453, 0.1912],\n",
      "        [0.6098, 0.2187, 0.1716],\n",
      "        [0.0733, 0.4893, 0.4374]])\n",
      "log: tensor([[-1.5492, -0.8595, -1.0101],\n",
      "        [-0.4102, -1.9287, -1.6546],\n",
      "        [-0.4947, -1.5202, -1.7629],\n",
      "        [-2.6135, -0.7147, -0.8269]])\n",
      "target type: torch.LongTensor\n",
      "pred type: torch.FloatTensor\n",
      "NLLLoss: tensor(1.3597)\n",
      "CrossEntropyLoss : tensor(1.3597)\n",
      "CrossEntropyLoss : tensor([1.5492, 1.6546, 1.5202, 0.7147])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "# Four samples with three class scores each.\n",
    "x=torch.randn(4,3)\n",
    "print(\"x:\", x)\n",
    "sm = nn.Softmax(dim =1)\n",
    "print(\"softmax x:\", sm(x))\n",
    "print(\"log:\", torch.log(sm(x)))\n",
    "\n",
    "# NLLLoss takes log-probabilities, so softmax and log are applied by hand here.\n",
    "loss = nn.NLLLoss()\n",
    "target = torch.tensor([0,2,1,1])\n",
    "print(\"target type:\",target.type())\n",
    "print(\"pred type:\",x.type())\n",
    "print(\"NLLLoss:\",loss(torch.log(sm(x)),target))\n",
    "\n",
    "# CrossEntropyLoss takes the raw scores and yields the same value as above.\n",
    "loss = nn.CrossEntropyLoss() # Softmax–Log–NLLLoss\n",
    "print(\"CrossEntropyLoss :\", loss(x, target))\n",
    "\n",
    "# reduction='none' returns one loss value per sample instead of their mean.\n",
    "loss = nn.CrossEntropyLoss(reduction= 'none') # Softmax–Log–NLLLoss\n",
    "print(\"CrossEntropyLoss :\", loss(x, target))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BCELoss和BCEWithLogitsLoss(Sigmoid->BCELoss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: tensor([[-1.1235, -0.0948, -0.8286],\n",
      "        [ 0.8362, -1.9370,  0.7749],\n",
      "        [ 1.1140, -0.2855,  1.0280]])\n",
      "Sigmoid x: tensor([[0.2454, 0.4763, 0.3039],\n",
      "        [0.6977, 0.1260, 0.6846],\n",
      "        [0.7529, 0.4291, 0.7365]])\n",
      "BCELoss : tensor(0.8267)\n",
      "CrossEntropyLoss : tensor(0.8267)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "# Three samples with three independent binary scores each.\n",
    "x=torch.randn(3,3)\n",
    "print(\"x:\", x)\n",
    "sm = nn.Sigmoid()\n",
    "print(\"Sigmoid x:\", sm(x))\n",
    "\n",
    "\n",
    "target = torch.FloatTensor([[0,1,1],[1,0,0],[0,1,0]])\n",
    "loss = nn.BCELoss() # takes probabilities, so the sigmoid is applied by hand\n",
    "print(\"BCELoss :\", loss(sm(x), target))\n",
    "\n",
    "loss = nn.BCEWithLogitsLoss() # Sigmoid + BCELoss in one step; takes raw scores\n",
    "print(\"BCEWithLogitsLoss :\", loss(x, target))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/60], Loss: 33.5815\n",
      "Epoch [10/60], Loss: 13.8363\n",
      "Epoch [15/60], Loss: 5.8369\n",
      "Epoch [20/60], Loss: 2.5958\n",
      "Epoch [25/60], Loss: 1.2825\n",
      "Epoch [30/60], Loss: 0.7501\n",
      "Epoch [35/60], Loss: 0.5341\n",
      "Epoch [40/60], Loss: 0.4463\n",
      "Epoch [45/60], Loss: 0.4103\n",
      "Epoch [50/60], Loss: 0.3955\n",
      "Epoch [55/60], Loss: 0.3891\n",
      "Epoch [60/60], Loss: 0.3862\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 1: custom loss written as a plain function\n",
    "def my_mse_loss(x, y):\n",
    "    \"\"\"Mean of the squared element-wise difference between x and y.\"\"\"\n",
    "    squared_error = torch.pow(x - y, 2)\n",
    "    return torch.mean(squared_error)\n",
    "\n",
    "## 2: custom loss written as an nn.Module subclass\n",
    "class My_loss(nn.Module):\n",
    "    \"\"\"Mean squared error implemented as a module.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(My_loss, self).__init__()\n",
    "\n",
    "    def forward(self, x, y):\n",
    "        \"\"\"Return the mean of the squared element-wise difference between x and y.\"\"\"\n",
    "        squared_error = torch.pow(x - y, 2)\n",
    "        return torch.mean(squared_error)\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n",
    "                    [9.779], [6.182], [7.59], [2.167], [7.042],\n",
    "                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n",
    "\n",
    "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n",
    "                    [3.366], [2.596], [2.53], [1.221], [2.827],\n",
    "                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)\n",
    "\n",
    "# Convert the numpy arrays to torch tensors\n",
    "inputs = torch.from_numpy(x_train)\n",
    "targets = torch.from_numpy(y_train)\n",
    "\n",
    "input_size = 1\n",
    "output_size = 1\n",
    "num_epochs = 60\n",
    "learning_rate = 0.001\n",
    "\n",
    "# Step 3: build the model, a single linear layer\n",
    "model = nn.Linear(input_size, output_size)\n",
    "\n",
    "# Model-related configuration: loss function and optimizer\n",
    "# Use the custom loss; equivalent to criterion = nn.MSELoss()\n",
    "criterion = My_loss()\n",
    "\n",
    "# Optimizer: stochastic gradient descent\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "loss_history = []\n",
    "# Step 4: train the model iteratively\n",
    "for epoch in range(num_epochs):\n",
    "    # Forward pass: compute the network output\n",
    "    outputs = model(inputs)\n",
    "\n",
    "    # Compute the loss\n",
    "    loss = criterion(outputs, targets)\n",
    "\n",
    "    # Backward pass and update in three steps: zero gradients -> backprop -> optimizer step\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # Record the loss every epoch and print progress every 5 epochs\n",
    "    loss_history.append(loss.item())\n",
    "    if (epoch + 1) % 5 == 0:\n",
    "        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXhUVb7u8e8PiIRRFLFFICYiyiQECAKCijKIgMNBUbppbb120w6tdKvQKA44gNB6HPqKcnBo9JrWgyiIgjMgCIoMgoyNIEEiqIAyxAgGWPePCkWqqIRKUsneVfV+nocn2at2qn4U4c3K2muvZc45REQk/lXxugAREYkNBbqISIJQoIuIJAgFuohIglCgi4gkiGpevfAJJ5zg0tPTvXp5EZG4tGTJku3OuQaRHvMs0NPT01m8eLFXLy8iEpfMbFNxj2nIRUQkQSjQRUQShAJdRCRBeDaGHklBQQG5ubns3bvX61IESE1NpXHjxqSkpHhdiohEwVeBnpubS506dUhPT8fMvC4nqTnn2LFjB7m5uWRkZHhdjohEwVdDLnv37qV+/foKcx8wM+rXr6/flkTiiK8CHVCY+4j+LUTii+8CXUQkUe0tOMBjH6xjy85fKuT5FehhcnNzufTSS2nWrBlNmzZl6NCh/PrrrxHP3bJlC1dcccVRn7Nv377s3LmzTPWMGjWKRx999Kjn1a5du8THd+7cydNPP12mGkSk/CYv3kzze97lnx99xdx12yrkNeI70LOzIT0dqlQJfMzOLtfTOecYMGAAl112GV999RXr1q0jLy+PkSNHHnHu/v37Ofnkk5kyZcpRn3fmzJnUq1evXLWVlwJdxBu7fikgfcQMhk/5EoDLMk9m0FlpFfJa8Rvo2dkwZAhs2gTOBT4OGVKuUJ81axapqalcd911AFStWpXHH3+cF154gfz8fCZNmsTAgQO5+OKL6d27Nzk5ObRu3RqA/Px8rrzyStq0acNVV11Fp06dgksbpKens337dnJycmjRogV/+tOfaNWqFb179+aXXwK/ej377LN07NiRtm3bcvnll5Ofn19irRs3bqRLly507NiRe+65J9iel5dHjx49aN++PWeeeSZvvvkmACNGjGDDhg1kZmYybNiwYs8TkdiZ8PEG2t7/fvB47rDzeWJQuwp7vfgN9JEjITz08vMD7WW0atUqOnToENJWt25d0tLSWL9+PQCffvopL774IrNmzQo57+mnn+a4447jyy+/5J577mHJkiURX+Orr77i5ptvZtWqVdSrV4/XX38dgAEDBrBo0SKWL19OixYteP7550usdejQodx4440sWrSIk046KdiemprK1KlTWbp0KbNnz+b222/HOcfYsWNp2rQpy5Yt45FHHin2PBEpvx927yV9xAzGvrMWgD+feyo5Y/uRVr9mhb6ur+ahl8o335SuPQrOuYgzO4q29+rVi+OPP/6Icz755BOGDh0KQOvWrWnTpk3E18jIyCAzMxOADh06kJOTA8DKlSu5++672blzJ3l5eVx44YUl1jp//vzgD4Orr76av//978Fa77rrLubOnUuVKlX49ttv+f777yP+nSKdV/SHg4iU3oNvr+b5TzYGjxeN7EmDOtUr5bXjN9DT0gLDLJHay6hVq1bBkDxk9+7dbN68maZNm7JkyRJq1aoV8Wuj7d1Wr374H7Zq1arBIZdrr72WadOm0bZtWyZNmsScOXOO+lyRfvhkZ2ezbds2lixZQkpKCunp6RHnkkd7nohEJ2f7z3R/dE7weGTfFvzp3FMrtYb4HXIZPRpqhv36UrNmoL2MevToQX5+Pi+99BIABw4c4Pbbb+faa6+lZvhrhenWrRuTJ08GYPXq1axYsaJUr71nzx4aNmxIQUEB2VFcB+jatSuvvvoqQMj5u3bt4sQTTyQlJYXZs2ezqfCHXp06ddizZ89RzxOR0rvllS9CwvzL
Ub0rPcwhngN98GCYOBFOOQXMAh8nTgy0l5GZMXXqVF577TWaNWvG6aefTmpqKmPGjDnq1950001s27aNNm3aMG7cONq0acOxxx4b9Ws/+OCDdOrUiV69etG8efOjnv/kk08yfvx4OnbsyK5du4LtgwcPZvHixWRlZZGdnR18rvr169O1a1dat27NsGHDij1PRKK38ttdpI+YwVvLtwDw6MC25IztR91Ub9Y/smiHCsysKrAY+NY51z/sserAS0AHYAdwlXMup6Tny8rKcuEbXKxZs4YWLVpEXbyfHDhwgIKCAlJTU9mwYQM9evRg3bp1HHPMMV6XVi7x/G8iUlEOHnQMmvgZn+f8CMBxNVP49M4epKZUrfDXNrMlzrmsSI+VZgx9KLAGqBvhseuBn5xzp5nZIGAccFWpK41j+fn5nH/++RQUFOCc45lnnon7MBeRIy3YsJ3fPbswePzCtVlc0Pw3HlZ0WFSBbmaNgX7AaOC2CKdcCowq/HwK8JSZmUuieXB16tTRlnoiCazgwEF6PvYxm3YEpks3P6kOM249h6pV/LPmUbQ99CeA4UCdYh5vBGwGcM7tN7NdQH1ge9GTzGwIMAQgrRyzUUREKtO7K7dyw8tLg8dTbuhCVvqR05e9dtRAN7P+wA/OuSVm1r240yK0HdE7d85NBCZCYAy9FHWKiFS6X349QLsH32dvwUEAzj29AS9e19G3K5FG00PvClxiZn2BVKCumb3snPt9kXNygSZArplVA44Ffox5tSIileTfC7/hrqmHpx+/99dzOeOk4gYp/OGoge6cuxO4E6Cwh35HWJgDTAf+AHwKXAHMSqbxcxFJHDvzfyXzgQ+CxwM7NOaRgW09rCh6ZZ6HbmYPmNklhYfPA/XNbD2Bi6YjYlGcF6pWrUpmZmbwT05ODosXL+bWW28FYM6cOSxYsCB4/rRp01i9enWpX6e45W4PtUe7NK+IxM5Ts74KCfN5w8+PmzCHUt7675ybA8wp/PzeIu17gYGxLMwrNWrUYNmyZSFt6enpZGUFpn3OmTOH2rVrc/bZZwOBQO/fvz8tW7aMaR3RLs0rIuX33a69dH74o+Dxzec3ZdiF8XezXfzeKVqJ5syZQ//+/cnJyWHChAk8/vjjZGZm8vHHHzN9+nSGDRtGZmYmGzZsYMOGDfTp04cOHTpwzjnnsHZtYLW14pa7LU7RpXknTZrEgAED6NOnD82aNWP48OHB895//326dOlC+/btGThwIHl5eRXzJogkqPveXBkS5kvu7hmXYQ4+Xpzr/rdWsXrL7pg+Z8uT63Lfxa1KPOeXX34JroaYkZHB1KlTg4+lp6dzww03ULt2be644w4ALrnkEvr37x8cHunRowcTJkygWbNmLFy4kJtuuolZs2YFl7u95pprGD9+fKlrX7ZsGV988QXVq1fnjDPO4JZbbqFGjRo89NBDfPjhh9SqVYtx48bx2GOPce+99x79CUWS3IZtefT474+Dx/f2b8n/6ZbhYUXl59tA90qkIZdo5eXlsWDBAgYOPDz6tG/fPqD45W6j1aNHj+DaMC1btmTTpk3s3LmT1atX07VrVwB+/fVXunTpUqbaRZKFc44bX17Ku6u+C7atvP9CaleP/zj07d/gaD1pPzp48CD16tUr9gdCeeauhi+7u3//fpxz9OrVi1deeaXMzyuSTL7M3cklT80PHj85KJNLMxt5WFFsaQy9lMKXoS16XLduXTIyMnjttdeAQE9g+fLlQPHL3ZZH586dmT9/fnA3pfz8fNatWxeT5xZJJAcPOi4bPz8Y5ifWqc5/HuqTUGEOCvRSu/jii5k6dSqZmZnMmzePQYMG8cgjj9CuXTs2bNhAdnY2zz//PG3btqVVq1bBvTqLW+62PBo0aMCkSZP47W9/S5s2bejcuXPwIqyIBPx74TecetdMlm3eCcCk6zry+cieVK9W8SsjVraol8+NtURbPjdR6d9E4lX+r/tpee97weMzGx3LtJu7+mox
rbKI1fK5IiJx4absJcxccfii56iLW3Jt1/iewRINBbqIJIztefvIeujDkLaND/f17WJasea7QHfOJc2b73dajkfiSZ8n5rL2u8MTFp4Z3J6LzmzoYUWVz1eBnpqayo4dO6hfv75C3WPOOXbs2EFqaqrXpYiU6OtteVxQ5AYhgJyx/Tyqxlu+CvTGjRuTm5vLtm3bvC5FCPyAbdy4sddliBQrfcSMkOPXb+xCh1P8t/FEZfFVoKekpJCRkfgXLkSkfJZs+pHLn/k0pC1Ze+VF+SrQRUSOJrxX/tHt59G0QeTlqJONAl1E4kL4vp7NTqzNB7ed52FF/qNAFxFfc86RcefMkLZFI3vSoE71Yr4ieSnQRcS3/jV/I/e/dXhHsItan8Qzv+/gYUX+dtRAN7NUYC5QvfD8Kc65+8LOuRZ4BPi2sOkp59xzsS1VRJJFwYGDNBv5Tkjb6gcupOYx6oOWJJp3Zx9wgXMuz8xSgE/M7B3n3Gdh5/2vc+4vsS9RRJLJA2+t5oX5G4PHN5zXlBEXxecOQpXtqKstuoBD+5qlFP7RLYQiElN5+/aTPmJGSJivH31RYoV5djakp0OVKoGPMVpK+5Cofn8xs6rAEuA0YLxzbmGE0y43s3OBdcDfnHObIzzPEGAIQFpaWpmLFpHEcv2kRXy09ofg8YOXtebqzqd4WFEFyM6GIUMgPz9wvGlT4Bhg8OCYvESpls81s3rAVOAW59zKIu31gTzn3D4zuwG40jl3QUnPFWn5XBFJLj/s3stZYz4KaUvYxbTS0wMhHu6UUyAnJ+qnidnyuc65nWY2B+gDrCzSvqPIac8C40rzvCKSfM57ZDabduQHj5+7JoueLX/jYUUV7JtvStdeBtHMcmkAFBSGeQ2gJ2GBbWYNnXNbCw8vAdbErEIRSShffb+HXo/PDWlLitv209Ii99BjOPwczRZ0DYHZZvYlsAj4wDn3tpk9YGaXFJ5zq5mtMrPlwK3AtTGrUEQSRvqIGSFhPu3mrhUb5hV8EbJURo+GmjVD22rWDLTHiK+2oBORxPTZ1zsYNPHwTOfq1arwn4cuqtgXDb8ICYEAnTgxZhchy1TTyJGBYZa0tECYl7KWksbQFegiUqHCF9P6eFh3TqlfqxJeOD0mFyH9RnuKikile2v5Fm555Yvg8ZmNjuWtW7pVXgGVcBHSbxToIhJTkRbTWnpPL46vdUzlFlIJFyH9JpqLoiIiUfmfjzeEhPllmSeTM7Zf5Yc5VMpFSL9RD11Eyu3X/Qc5/e7QxbTWPtiH1JSqHlXE4YuN5bwIGU8U6CJSLndPW8HLnx0el761RzNu63W6hxUVMXhwQgd4OAW6iJTJ7r0FtBn1fkjbhjF9qVolAW/bjxMaQxcpDT/dqOKh3z+3MCTMx11+Jjlj+ynMPaYeuki0KmG1PL/buusXujw8K6QtKW7bjxO6sUgkWgl6o0q0Oo35kO937wseT7quI93PONHDipKTbiwSiYUkvFEFYM3W3Vz05LyQNvXK/UmBLhKtJLxRJfy2/bdv6UbrRsd6VI0cjS6KikQriW5Umb9+e0iYH1sjhZyx/RTmPqceuki0kuRGlfBe+bzh59Pk+JrFnC1+okAXKY0EvlHljaW53DZ5efC4Y/pxvHbD2R5WJKWlQBdJcgcPOk69K3QxreX39ubYmikeVSRlpUAXSWJPzfqKR99fFzy+Mqsx/7iirYcVSXlEs6doKjAXqF54/hTn3H1h51QHXgI6ADuAq5xzOTGvVkRiYm/BAZrf825Im+eLaUm5RdND3wdc4JzLM7MU4BMze8c591mRc64HfnLOnWZmgwhsIn1VBdQrIuU0fMpyJi/ODR7f0ft0/nJBMw8rklg5aqC7wK2keYWHKYV/wm8vvRQYVfj5FOApMzPn1W2oInKEnfm/kvnAByFtX4/pSxWtv5IwohpDN7OqwBLgNGC8c25h2CmNgM0Azrn9ZrYLqA9sD3ueIcAQgLQEvhlDxG/CpyI+flVb
/qtdY4+qkYoS1Y1FzrkDzrlMoDFwlpm1Djsl0o/4I3rnzrmJzrks51xWgwYNSl+tiJTK6i27jwjznLH9FOYJqlSzXJxzO81sDtAHWFnkoVygCZBrZtWAY4EfY1WkiJReeJCPfff/MmjnWjhzZ8LOpU920cxyaQAUFIZ5DaAngYueRU0H/gB8ClwBzNL4uYg3Zq39nv8zKXQl05xx/Q8fJNmSv8kkmh56Q+DFwnH0KsBk59zbZvYAsNg5Nx14Hvh/ZraeQM98UIVVLCLFCu+Vvzz7n3T7PHRXIfLzA8sXKNATjtZDF0kAk+ZvZNRbq0Pacsb2C+ysFOn/uBkcPFhJ1UkslbQeulZbFKkolbBdnXOO9BEzQsL8g7+de3i98uJmk2mWWUJSoItUhEPb1W3aFOghH9quLoahfs+0lWTcGboGS87YfjT7TZ3DDUm05K9oyEWkYlTgdnX7DxzktJHvhLQtvrsnJ9SuHvkLsrMTfsnfZKIhF9Fu9ZWtgraru2z8/JAwb1SvBjlj+xUf5hAI75ycwJh5To7CPIFptcVkoN3qK1+Mt6uLdNu+FtOScOqhJ4ORIw+H+SGHpq5JxYjh2HX6iBkhYd6iYV1yxvZTmMsR1ENPBkm6W72nYrBd3fof8uj52MchbVpMS0qiQE8GSbhbvS+UY7u68BuE+rQ6iQlXd4hFVZLAFOjJYPTo0DF00NQ1n5q7bhvXvPB5SFtwTrnIUSjQk0GS7FYf78J75dp4QkpLgZ4sEni3+nj34oIc7pu+KqRNvXIpCwW6iIfCe+UTft+ePq0belSNxDtNW5TE58Obqu5848uIG08ozKU81EOXxOazm6qcc0esv/L2Ld1o3ejYSq9FEo/WcpHEVoFrqpRWnyfmsva7PSFtGiuX0tJaLpK8fHBT1b79B0gfMSMkzD+/q0fpw9yHQ0fiLxpykcTm8U1V4ePkUMZeuc+GjsSfjtpDN7MmZjbbzNaY2SozGxrhnO5mtsvMlhX+ubdiyhUpJY/WA9+et++IMF/7YJ+yD7FoPR6JQjQ99P3A7c65pWZWB1hiZh8451aHnTfPOdc/wteLeMeDm6rCgzzjhFrMvqN7+Z7UB0NH4n9HDXTn3FZga+Hne8xsDdAICA90EX+qpJuqln7zEwOeXhDStvHhvpjFYDEtrccjUSjVRVEzSwfaAQsjPNzFzJab2Ttm1qqYrx9iZovNbPG2bdtKXayIX6WPmBES5pdmnkzO2H6xCXPQVnISlagvippZbeB14K/Oud1hDy8FTnHO5ZlZX2AacMQiFM65icBECExbLHPVIj7x2uLNDJvyZUhbhUxF1Ho8EoWo5qGbWQrwNvCec+6xKM7PAbKcc9uLO0fz0CXehY+VX98tg3v6t/SoGkkWJc1DP2oP3QK/Mz4PrCkuzM3sJOB755wzs7MIDOXsKEfNIr5135srefHT0PFs3SAkfhDNkEtX4GpghZktK2y7C0gDcM5NAK4AbjSz/cAvwCDn1S2oIhUovFf+2JVtGdC+sUfViISKZpbLJ0CJV3acc08BT8WqKBG/6fvkPFZvDb10pF65+I3uFBUpwcGDjlPvCl1Ma9rNXclsUs+jikSKp0AXKUbMbtsXqSQKdJEwP+/bT6v73gtpW3hXD35TN9WjikSio0AXKUK9colnCnQRYPOP+Zzzj9khbWsf7ENqSlWPKhIpPQW6JD31yiVRKNAlaX26YQe/ffazkLaYLaYl4gEFuiSl8F752U3r8+8/dfaoGpHYUKBLUnnp0xzufXNVSJuGVyRRKNAlaYT3ym+54DRu732GR9WIxJ4CXRLeEx+u44kPvwppU69cEpECXRJaeK98/O/a069NQ4+qEalYpdqxSKRCZGdDejpUqRL4mJ1d7qf844uLjwjznLH9FOaS0NRDF29lZ8OQIYd3tN+0KXAMZdqN58BBR9OwxbRm3X4epzaoXd5KRXwvqh2LKoJ2LBIg0COPtPnxKadATk6pnqrdA+/zU35BSJvG
yiXRlGvHIpEK9c03pWuPIG/fflqHLaa1/N7eHFszpTyVicQdBbp4Ky0tcg89LS2qL9dt+yKHHfWiqJk1MbPZZrbGzFaZ2dAI55iZ/dPM1pvZl2bWvmLKlYQzejTUrBnaVrNmoL0EuT/lHxHmX42+SGEuSS2aHvp+4Hbn3FIzqwMsMbMPnHOri5xzEdCs8E8n4JnCjyIlO3Thc+TIwDBLWlogzEu4IBoe5GelH8/kG7pUZJUicSGaPUW3AlsLP99jZmuARkDRQL8UeKlwY+jPzKyemTUs/FqRkg0eHNWMliWbfuTyZz4NaVOPXOSwUo2hm1k60A5YGPZQI2BzkePcwraQQDezIcAQgLQox0hF4Mhe+R+7ZXB3/5YeVSPiT1EHupnVBl4H/uqc2x3+cIQvOWI+pHNuIjARAtMWS1GnJKk3luZy2+TlIW3qlYtEFlWgm1kKgTDPds69EeGUXKBJkePGwJbylyfJLLxX/o8r2nBlVpNizhaRowa6BVb7fx5Y45x7rJjTpgN/MbNXCVwM3aXxcymrh99Zw/98/HVIm3rlIkcXTQ+9K3A1sMLMlhW23QWkATjnJgAzgb7AeiAfuC72pUoyCO+VT/5zF87KON6jakTiSzSzXD4h8hh50XMccHOsipLk87tnP2PBhh0hbeqVi5SO7hQVT+0/cJDTRr4T0jZv+Pk0Ob5mMV8hIsVRoItnmo2cScGB0MlO6pWLlJ0CXSrdrl8KaHv/+yFtK0b1pk6qFtMSKQ8FulSq8IuetatXY+X9F3pUjUhiUaBLpfhu1146P/xRSNuGMX2pWqXE6+0iUgoKdKlw4b3y7mc0YNJ1Z3lUjUjiUqBLhVm1ZRf9/vlJSJsueopUHG0SHUsVsNlxvEofMSMkzMddfmZyhLm+B8RD6qHHSow3O45XH635nutfDN0rNimCHPQ9IJ7TJtGxEsPNjuNV+Fh59h870fW0EwoPsku1iUVc0veAVAJtEl0ZYrDZcbz61/yN3P/W6pC2kF55svRck/h7QPxBgR4r5dzsOB4558i4c2ZI24e3nctpJ9YJPXHkyMNhfkh+fqA9kQI9Cb8HxF90UTRWyrjZcby6e9qKI8I8Z2y/I8MckqfnmmTfA+I/CvRYGTwYJk4MjJeaBT5OnFh5PdBKml2x/8BB0kfM4OXPDofx4rt7lnzhs7geaqL1XL3+HpCkp4uiiSB8jBoCPcMYh8nlzyxgyaafgsdNjq/BvOEX+KY+kWRQ0kVRBXoiqODZFXv2FnDmqNDFtNY+2IfUlKrRP0kyzHIRqQQK9ERXpQpE+nc0g4MHy/XU4UvcXtT6JJ75fYdyPaeIlF25pi2a2QtAf+AH51zrCI93B94ENhY2veGce6Ds5UqpVcDsityf8uk2bnZI29dj+lJFi2mJ+FY00xYnAU8BL5VwzjznXP+YVCSlN3p05DHqMs6uCL9B6NYezbit1+nlqVBEKkE0e4rONbP0ii9FyuzQWHQ5x6iXb97JpePnh7QlzW37IgkgVjcWdTGz5cAW4A7n3KpIJ5nZEGAIQFqiTVnz2uDB5brIGN4rf+KqTC5r16i8VYlIJYpFoC8FTnHO5ZlZX2Aa0CzSic65icBECFwUjcFrSzm9u3IrN7y8NKRNvXKR+FTuQHfO7S7y+Uwze9rMTnDObS/vc0vFCu+VT/5zF87KON6jakSkvMod6GZ2EvC9c86Z2VkE7j7dUe7KpMJM+HgDY99ZG9KmXrlI/Itm2uIrQHfgBDPLBe4DUgCccxOAK4AbzWw/8AswyHk1uV1KFGkxrdl3dCfjhFoeVSQisRTNLJffHuXxpwhMaxQfu33ycl5fmhvSpl65SGLR8rkJ7tf9Bzn97ndC2pbd24t6NY/xqCIRqSgK9AR20ZPzWLM1eM2a5ifV4d2/nuthRSJSkbR8bmnFwSbAu/ILSB8xIyTM//NQH4W5SIJTD7004mArtfCpiP/VrhGPX5XpUTUiUpm02mJp+HgT4B/27OWs0R+F
tG18uC9mWkxLJJFok+hY8elWaj3+ew4btv0cPB7e5wxu6n6ahxWJiBfiawzd6/Frn22ltv6HPNJHzAgJ85yx/RTmIkkqfnrofhi/jvEyteURPlb++o1n0+GU4yq9DhHxj/jpoY8cGRqkEDgeObLyavDBJsCLcn4MCXOzQK9cYS4i8XNRtAK3WYsX4b1y3bYvknxKuigaPz10n41fV6YZX24NCfPmJ9UhZ2w/hbmIhIifMXQfjV9XlkiLaS2+uycn1K7uUUUi4mfx00P3wfh1ZXpu3tchYd7vzIbkjO2nMBeRYsVPDx3Kvc1aPCg4cJBmI0MX01r9wIXUPCa+/qlEpPIpJXxk1PRVTFqQEzy+qXtThvdp7l1BIhJXFOg+sGdvAWeOej+kbcOYvlStotv2RSR6CnSP/eGFz/l43bbg8Zj/OpPfdUr8mTsiEnvRbEH3AtAf+ME51zrC4wY8CfQF8oFrnXNLw8+TUN/t2kvnh7WYlojETjQ99EkEtph7qZjHLwKaFf7pBDxT+FGK0W3cLHJ/+iV4/PwfsujR4jceViQiiSCaPUXnmll6CadcCrxUuDH0Z2ZWz8waOue2xqjGhLHu+z30fnxuSJv29RSRWInFGHojYHOR49zCtiMC3cyGAEMA0pLgDs+iwm/bf/PmrrRtUs+jakQkEcXixqJIg74RF4hxzk10zmU557IaNGgQg5f2vwUbtoeEea1jqpIztp/CXERiLhY99FygSZHjxsCWGDxv3Avvlc8ddj5p9Wt6VI2IJLpY9NCnA9dYQGdgV7KPn7+57NuQMG/bpB45Y/spzEWkQkUzbfEVoDtwgpnlAvcBKQDOuQnATAJTFtcTmLZ4XUUV63eRFtP64p5eHFfrGI8qEpFkEs0sl98e5XEH3ByziuLUm8u+Zeiry4LHA9o14rGrMj2sSESSje4ULadIi2n956E+VK9W1aOKRCRZKdDLYeLcDYyZuTZ4/MgVbRiY1aSErxARqTgK9DL4ed9+Wt33Xkjb12P6UkWLaYmIhxTopTRlSS53vLY8ePyv6zpy/hkneliRiEiAAj1Ku/cW0KbIErc1Uqqy5sE+HlYkIhJKgR6F8LHyOXd0J10bNIuIzyjQS/DDnr2cNfrwErfXd8vgntmpcxwAAAY1SURBVP4tPaxIRKR4CvRijJ6xmmfnbQwef35XD06sm+phRSIiJVOgh9m042fOe2RO8PjvfZpzY/em3hUkIhIlBXoRQ1/9gjeXHV5XbPl9vTm2RoqHFYmIRE+BDqzasot+//wkePyPK9pwpW4QEpE4k9SB7pxj0MTPWLjxRwDqpFZj0ciepKbotn0RiT9JG+iffb2DQRM/Cx4/e00WvVpqX08RiV9JF+j7Dxyk1+Nz2bj9ZwBOO7E27w49h2pVY7E0vIiId5Iq0N9d+R03vLwkeDz5z104K+N4DysSEYmdpAj0vQUHaP/gB+T/egCArqfV5+XrO2GmxbREJHEkfKD/76Jv+PvrK4LH7ww9hxYN63pYkYhIxYgq0M2sD/AkUBV4zjk3Nuzxa4FHgG8Lm55yzj0XwzpLbVd+AW0fOLyY1oD2jXjsSu0gJCKJK5o9RasC44FeQC6wyMymO+dWh536v865v1RAjaU2fvZ6HnnvP8HjecPPp8nx2qBZRBJbND30s4D1zrmvAczsVeBSIDzQPff97r10GnN4Ma0bzmvKiIuae1iRiEjliSbQGwGbixznAp0inHe5mZ0LrAP+5pzbHH6CmQ0BhgCkpaWVvtoSjJq+ikkLcoLHi0b2pEGd6jF9DRERP4sm0CNNBXFhx28Brzjn9pnZDcCLwAVHfJFzE4GJAFlZWeHPUSYbt//M+Y/OCR7f3a8Ffzzn1Fg8tYhIXIkm0HOBogubNAa2FD3BObejyOGzwLjyl1Yy5xx/+fcXzFixNdi2YlRv6qRqMS0RSU7RBPoioJmZZRCYxTII+F3RE8ysoXPuULJeAqyJaZVhVuTu4uKnDi+m9diVbRnQ
vnFFvqSIiO8dNdCdc/vN7C/AewSmLb7gnFtlZg8Ai51z04FbzewSYD/wI3BtRRW8+cf8YJjXr3UM80dcoMW0RESIch66c24mMDOs7d4in98J3Bnb0iKrXb0aXU+rz/XdMriguRbTEhE5JO7uFD2u1jFk/7Gz12WIiPiOlhgUEUkQCnQRkQShQBcRSRAKdBGRBKFAFxFJEAp0EZEEoUAXEUkQCnQRkQRhzsVk0cPSv7DZNmBTFKeeAGyv4HLikd6X4um9iUzvS/Hi6b05xTnXINIDngV6tMxssXMuy+s6/EbvS/H03kSm96V4ifLeaMhFRCRBKNBFRBJEPAT6RK8L8Cm9L8XTexOZ3pfiJcR74/sxdBERiU489NBFRCQKCnQRkQThy0A3syZmNtvM1pjZKjMb6nVNfmJmVc3sCzN72+ta/MTM6pnZFDNbW/i908XrmvzCzP5W+H9ppZm9YmapXtfkFTN7wcx+MLOVRdqON7MPzOyrwo/HeVljWfky0AnsTXq7c64F0Bm42cxaelyTnwylgjfijlNPAu8655oDbdF7BICZNQJuBbKcc60J7A08yNuqPDUJ6BPWNgL4yDnXDPio8Dju+DLQnXNbnXNLCz/fQ+A/ZiNvq/IHM2sM9AOe87oWPzGzusC5wPMAzrlfnXM7va3KV6oBNcysGlAT2OJxPZ5xzs0lsJl9UZcCLxZ+/iJwWaUWFSO+DPSizCwdaAcs9LYS33gCGA4c9LoQnzkV2Ab8q3A46jkzq+V1UX7gnPsWeBT4BtgK7HLOve9tVb7zG+fcVgh0KIETPa6nTHwd6GZWG3gd+KtzbrfX9XjNzPoDPzjnlnhdiw9VA9oDzzjn2gE/E6e/Nsda4XjwpUAGcDJQy8x+721VUhF8G+hmlkIgzLOdc294XY9PdAUuMbMc4FXgAjN72duSfCMXyHXOHfpNbgqBgBfoCWx0zm1zzhUAbwBne1yT33xvZg0BCj/+4HE9ZeLLQDczIzAWusY595jX9fiFc+5O51xj51w6gYtas5xz6mkBzrnvgM1mdkZhUw9gtYcl+ck3QGczq1n4f6sHumAcbjrwh8LP/wC86WEtZVbN6wKK0RW4GlhhZssK2+5yzs30sCbxv1uAbDM7BvgauM7jenzBObfQzKYASwnMIPuCBLnVvSzM7BWgO3CCmeUC9wFjgclmdj2BH4ADvauw7HTrv4hIgvDlkIuIiJSeAl1EJEEo0EVEEoQCXUQkQSjQRUQShAJdRCRBKNBFRBLE/wdZTKJoidp+mwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9b3/8ddnZrIQkhAICYQ1LJFFNEED2Cog4ILLBaxS7WJxaWnrba+292HVn7feh/b21t7bq95H21/78LZW/V2qWBek1i6IVcQFCRgUBGURMBAgJIYlIdvM9/fHDBggmMkyGc7M+/l4zGPmfOc7cz5HxvecfOec8zXnHCIi4j2+eBcgIiKdowAXEfEoBbiIiEcpwEVEPEoBLiLiUYGeXFn//v1dYWFhT65SRMTz1qxZs985l3die48GeGFhIWVlZT25ShERzzOzHW21awhFRMSjFOAiIh7VboCb2RgzK291O2hmt5lZPzNbZmabI/d9e6JgEREJa3cM3Dn3AVACYGZ+YBfwHHAnsNw5d7+Z3RlZviOGtYp4XnNzMxUVFTQ0NMS7FDkNpaenM2TIEFJSUqLq39EfMWcBW51zO8xsLnBhpP0x4BUU4CKfqaKigqysLAoLCzGzeJcjpxHnHNXV1VRUVDBixIioXtPRMfDrgCcijwc45yojK64E8tt6gZktNLMyMyurqqrq4OpEEktDQwO5ubkKbzmJmZGbm9uhv86iDnAzSwXmAH/oSFHOuYedc6XOudK8vJMOYxRJOgpvOZWOfjY6sgd+GbDWObc3srzXzAoiKy0A9nVozR3wfPku/vetNg+DFBFJWh0J8C/x6fAJwFJgQeTxAuD57irqRC++V8kjr38Uq7cXSSqZmZkxed+qqiqmTJnCxIkTee2112KyDi+74YYbePrpp7v1PaMKcDPLAC4Gnm3VfD9wsZltjjx3f7dW1sqovEx2VtfTHAzFahUi0kXLly9n7NixvPPOO0ydOjWq1wSDwRhXBS0tLTFfR7xEFeDOuXrnXK5z7kCrtmrn3CznXFHkviZWRY7Oz6Ql5NhRXRerVYgkHecct99+OxMmTOCss85i8eLFAFRWVjJt2jRKSkqYMGECr732GsFgkBtuuOFY3wcffPC49yovL+cHP/gBL774IiUlJRw5coQnnniCs846iwkTJnDHHZ8eoJaZmck999zDlClTePPNN497n61btzJ79mzOPfdcpk6dyqZNmzhw4ACFhYWEQuEduPr6eoYOHUpzc3Ob/SG8t/v973+fGTNmcPvtt1NUVMTRgyhCoRCjR49m//79x627rq6Om266iUmTJjFx4kSefz48qPDoo48yd+5cZs+ezZgxY7j33nuPveaBBx5gwoQJTJgwgYceeuhY++OPP87ZZ59NcXEx119//bH2FStW8PnPf56RI0d2y954j14LpbNG5YX/5Nuyr47R+Vlxrkake9z7xw28v/tgt77n+EHZ/Os/nBlV32effZby8nLWrVvH/v37mTRpEtOmTeP3v/89l156KXfffTfBYJD6+nrKy8vZtWsX69evB6C2tva49yopKeG+++6jrKyMX/ziF+zevZs77riDNWvW0LdvXy655BKWLFnCvHnzqKurY8KECdx3330n1bRw4UJ+/etfU1RUxKpVq7jlllt4+eWXKS4u5tVXX2XGjBn88Y9/5NJLLyUlJeWU/QE+/PBDXnrpJfx+Pzk5OSxatIjbbruNl156ieLiYvr373/cun/84x8zc+ZMHnnkEWpra5k8eTIXXXQRAG+//Tbr168nIyODSZMmccUVV2Bm/O53v2PVqlU455gyZQrTp08nNTWVH//4x7z++uv079+fmppP920rKytZuXIlmzZtYs6cOVxzzTXR/+O2wRMBPjKvNwBbqw7HuRKRxLFy5Uq+9KUv4ff7GTBgANOnT2f16tVMmjSJm266iebmZubN
m0dJSQkjR45k27ZtfPe73+WKK67gkksu+cz3Xr16NRdeeCFHjzz7yle+wooVK5g3bx5+v5+rr776pNccPnyYN954g/nz5x9ra2xsBODaa69l8eLFzJgxgyeffJJbbrnlM/sDzJ8/H7/fD8BNN93E3Llzue2223jkkUe48cYbT1r/3/72N5YuXcrPfvYzIHzI586dOwG4+OKLyc3NBeALX/gCK1euxMy46qqr6N2797H21157DTPjmmuuOfYF0a9fv2PrmDdvHj6fj/Hjx7N37166yhMBnpWewsDsdLbuU4BL4oh2TzlWTjWh+bRp01ixYgV/+tOfuP7667n99tv52te+xrp16/jrX//KL3/5S5566ikeeeSRDr83hM82PBqsrYVCIXJycigvLz/puTlz5nDXXXdRU1PDmjVrmDlzJnV1dafsDxwLVoChQ4cyYMAAXn75ZVatWsWiRYvarPmZZ55hzJgxx7WvWrXqpMP7zOyU2+icO+XhgGlpacf16yrPXMxqVH5v7YGLdKNp06axePFigsEgVVVVrFixgsmTJ7Njxw7y8/P5xje+wc0338zatWvZv38/oVCIq6++mh/96EesXbv2M997ypQpvPrqq+zfv59gMMgTTzzB9OnTP/M12dnZjBgxgj/8IXyqiXOOdevWAeFx88mTJ3Prrbdy5ZVX4vf7P7N/W77+9a/z1a9+lS9+8YttfoFceuml/PznPz8WrO+8886x55YtW0ZNTQ1HjhxhyZIlnH/++UybNo0lS5ZQX19PXV0dzz33HFOnTmXWrFk89dRTVFdXAxw3hNLdPLEHDjA6L5Nn1u76zG83EYneVVddxZtvvklxcTFmxn/8x38wcOBAHnvsMf7zP/+TlJQUMjMzefzxx9m1axc33njjsR8Sf/KTn3zmexcUFPCTn/yEGTNm4Jzj8ssvZ+7cue3WtGjRIr797W/zb//2bzQ3N3PddddRXFwMhIdR5s+fzyuvvBJV/xPNmTOHG2+8sc3hE4Af/vCH3HbbbZx99tk45ygsLOSFF14A4IILLuD6669ny5YtfPnLX6a0tBQI/1g6efJkIPwFMXHiRADuvvtupk+fjt/vZ+LEiTz66KPtbntnWHfsxkertLTUdXZCh8ff3M49z2/grbtmMbBPevcWJtJDNm7cyLhx4+JdRlIqKyvje9/7XoePUX/00UeP/TjbE9r6jJjZGudc6Yl9PTOEMjpyJIqGUUSko+6//36uvvrqdv9y8BrPBPio/KOHEirARaRj7rzzTnbs2MEFF1zQ4dfecMMNPbb33VGeCfD8rDQy0wLaAxfP68lhS/GWjn42PBPgZsao/EztgYunpaenU11drRCXkxy9Hnh6evS/8XnmKBSAUXm9eX3L/vY7ipymhgwZQkVFBbo2vrTl6Iw80fJUgI/Oz+TZtbs41NBMVnp0Uw6JnE5SUlKinm1FpD2eGUKBT6+JsrVKF7USEfFmgGscXETEWwE+PDeDgM/YoiNRRES8FeApfh/DczO0By4igscCHMI/ZGoPXETEgwGu6dVERMI8GeDh6dXq412KiEhcRTupcY6ZPW1mm8xso5l9zsz6mdkyM9scue8b62IhPIQCuiaKiEi0e+D/DfzFOTcWKAY2AncCy51zRcDyyHLMaXo1EZGwdgPczLKBacBvAZxzTc65WmAu8Fik22PAvFgV2ZqmVxMRCYtmD3wkUAX8zszeMbPfmFlvYIBzrhIgcp8fwzqPo+nVRESiC/AAcA7wK+fcRKCODgyXmNlCMyszs7LuuoDP6LxMtlbV6YpuIpLUognwCqDCObcqsvw04UDfa2YFAJH7fW292Dn3sHOu1DlXmpeX1x01Myo/k8ONLew92Ngt7yci4kXtBrhzbg/wsZmNiTTNAt4HlgILIm0LgOdjUmEbRml6NRGRqC8n+11gkZmlAtuAGwmH/1NmdjOwE5gfmxJP1vpQwvNH9++p1YqInFaiCnDnXDlw0ozIhPfGe5ymVxMR
8eCZmPDp9GoKcBFJZp4McAhPr6azMUUkmXk2wMcMyGLvwUZq6priXYqISFx4NsDHD8oGYGPlwThXIiISH94N8IJwgG/YfSDOlYiIxIdnAzw3M42B2em8v1t74CKSnDwb4ABnDspmgwJcRJKUpwN8/KBstlYdpqE5GO9SRER6nKcD/MxB2YQcfLDnULxLERHpcZ4O8PEFfQA0jCIiScnTAT6kby+y0gK8X6kjUUQk+Xg6wH0+Y9ygbB2JIiJJydMBDuHjwTdWHiIY0uQOIpJcPB/gZw7K5khzkO3VdfEuRUSkR3k+wI+eUq8fMkUk2Xg+wIvys0jxm8bBRSTpeD7AUwM+ivKzeF8XtRKRJOP5AIfwMMr7uw9olnoRSSoJEeBnDspm/+Emqg5plnoRSR4JEeCfXlpWwygikjyiCnAz225m75lZuZmVRdr6mdkyM9scue8b21JPbVzkSBSNg4tIMunIHvgM51yJc+7o7PR3Asudc0XA8shyXGSnpzCsX4aORBGRpNKVIZS5wGORx48B87peTueFrw2ua6KISPKINsAd8DczW2NmCyNtA5xzlQCR+/xYFBit8QXZbK+u53BjSzzLEBHpMYEo+53vnNttZvnAMjPbFO0KIoG/EGDYsGGdKDE6rSc5nlTYL2brERE5XUS1B+6c2x253wc8B0wG9ppZAUDkft8pXvuwc67UOVeal5fXPVW34cxB4WuDaxxcRJJFuwFuZr3NLOvoY+ASYD2wFFgQ6bYAeD5WRUZjQHYa/XqnKsBFJGlEM4QyAHjOzI72/71z7i9mthp4ysxuBnYC82NXZvvMLPxDpiZ3EJEk0W6AO+e2AcVttFcDs2JRVGeNL8jmd69vpzkYIsWfEOcoiYicUkKl3JmD+9AUDLGpUpMci0jiS6gAP3d4+GTQtTs/iXMlIiKxl1ABPqhPOgOz01mzQwEuIokvoQLczDh3eF8FuIgkhYQKcIBzhvdlV+0R9hxoiHcpIiIxlXABrnFwEUkWCRfg4wuySQv4NIwiIgkv4QI8NeCjeEiOAlxEEl7CBTiEx8E37D5AQ3Mw3qWIiMRMQgb4ucP70hx0vLdLp9WLSOJK2AAHNIwiIgktIQO8X+9URvbvrQAXkYSWkAEO4XHwtTs+wTkX71JERGIiYQP83OF9qa5rYkd1fbxLERGJiYQOcNA4uIgkroQN8NF5mWSlB1ijMzJFJEElbID7fMY5w8Lj4CIiiShhAxzCwygf7D3EwYbmeJciItLtEj7AnYPynbXxLkVEpNsldIAXD83BZ/ohU0QSU9QBbmZ+M3vHzF6ILI8ws1VmttnMFptZauzK7JzMtABjB2br0rIikpA6sgd+K7Cx1fJPgQedc0XAJ8DN3VlYdzl3eF/e2VlLMKQTekQksUQV4GY2BLgC+E1k2YCZwNORLo8B82JRYFedO7wvhxtb+GCPZqoXkcQS7R74Q8APgFBkOReodc61RJYrgMFtvdDMFppZmZmVVVVVdanYzpg8oh8Ab26r7vF1i4jEUrsBbmZXAvucc2taN7fRtc0xCufcw865UudcaV5eXifL7LxBOb0Ymdeb1zb3/JeHiEgsBaLocz4wx8wuB9KBbMJ75DlmFojshQ8BdseuzK6ZVpTHk6t30tgSJC3gj3c5IiLdot09cOfcXc65Ic65QuA64GXn3FeAvwPXRLotAJ6PWZVddMHo/jQ0h3Q4oYgklK4cB34H8H0z20J4TPy33VNS9ztvVC4Bn/Ha5v3xLkVEpNt0KMCdc684566MPN7mnJvsnBvtnJvvnGuMTYldl5kW4JxhfVmpABeRBJLQZ2K2NrWoP+t3H6CmrinepYiIdIukCfALivrjHLy+RXvhIpIYkibAzx6SQ3Z6QIcTikjCSJoA9/uM80f3Z+Xm/ZonU0QSQtIEOMDUojx2H2hga1VdvEsREemyJAvw/gCs1DCKiCSApArwof0yKMzN0PHgIpIQkirAIXw0ylvbqmlqCbXfWUTkNJZ0
AT61KI+6piDvaJIHEfG4pAvwz43Kxe8zVup4cBHxuKQL8Oz0FEqG5rBC4+Ai4nFJF+AQvjrhuxW11NbrtHoR8a6kDPBpZxw9rV6z9IiIdyVlgBcPySEnI4Vl7++JdykiIp2WlAEe8Pu4ZPwAlm/cR2NLMN7liIh0SlIGOMBlEwo41NjCGxpGERGPStoA//zoXLLSArz4XmW8SxER6ZSkDfC0gJ9Z4/JZtnEvzUGdlSki3pO0AQ5w2VkF1NY3s2pbTbxLERHpsKQO8Oln5JGR6ufP6zWMIiLe026Am1m6mb1tZuvMbIOZ3RtpH2Fmq8xss5ktNrPU2JfbvdJT/MwYk89fN+whGNIkDyLiLdHsgTcCM51zxUAJMNvMzgN+CjzonCsCPgFujl2ZsTN7wkD2H26ibLuGUUTEW9oNcBd2OLKYErk5YCbwdKT9MWBeTCqMsRlj80kN+Pjzep3UIyLeEtUYuJn5zawc2AcsA7YCtc65lkiXCmDwKV670MzKzKysqur0mwknMy3A9DPy+OuGPYQ0jCIiHhJVgDvngs65EmAIMBkY11a3U7z2YedcqXOuNC8vr/OVxtBlEwZSeaCBdRW18S5FRCRqHToKxTlXC7wCnAfkmFkg8tQQYHf3ltZzZo0bQIrfNIwiIp4SzVEoeWaWE3ncC7gI2Aj8Hbgm0m0B8Hysioy1Pr1S+Pyo/vx5fSXOaRhFRLwhmj3wAuDvZvYusBpY5px7AbgD+L6ZbQFygd/GrszYu2zCQD6uOcKG3QfjXYqISFQC7XVwzr0LTGyjfRvh8fCEcMmZA7l7yXpefK+SCYP7xLscEZF2JfWZmK31653K1KL+PLt2l07qERFPUIC3ct2koew52MCKD0+/wx1FRE6kAG9l5tgB9M9M5cnVO+NdiohIuxTgraQGfHzhnCEs37iPfYca4l2OiMhnUoCf4IulQ2kJOZ5duyvepYiIfCYF+AlG52cyqbAvT63+WMeEi8hpTQHehmsnDWPb/jpWb/8k3qWIiJySArwNl581kMy0gH7MFJHTmgK8DRmpAeaUDOLF9yo52NAc73JERNqkAD+F6yYNpaE5xNJyz16jS0QSnAL8FM4a3IdxBdksXv1xvEsREWmTAvwUzIxrS4fw3q4DbNh9IN7liIicRAH+GeZNHExqwKe9cBE5LSnAP0NORipXnl3A02sqqK1vinc5IiLHUYC345vTRlHfFOTxN3fEuxQRkeMowNsxZmAWs8bm8+gb2znSFIx3OSIixyjAo/CtC0dRU9fEU2UaCxeR04cCPAqTCvtROrwvD6/YRnMwFO9yREQABXjUvjV9FLtqj/CndyvjXYqICKAAj9rMsfkU5Wfy61e36iqFInJaaDfAzWyomf3dzDaa2QYzuzXS3s/MlpnZ5sh939iXGz8+n/Gt6aPYtOcQr3ygKddEJP6i2QNvAf7ZOTcOOA/4RzMbD9wJLHfOFQHLI8sJbU7JIAb1SedXr2yNdykiIu0HuHOu0jm3NvL4ELARGAzMBR6LdHsMmBerIk8XKX4fX586kre317BmR028yxGRJNehMXAzKwQmAquAAc65SgiHPJB/itcsNLMyMyurqvL+0MN1k4eSk5GivXARibuoA9zMMoFngNuccwejfZ1z7mHnXKlzrjQvL68zNZ5WMlID3HT+CF7auE974SISV1EFuJmlEA7vRc65ZyPNe82sIPJ8AbAvNiWefr4+dQQDstP40QsbdUSKiMRNNEehGPBbYKNz7oFWTy0FFkQeLwCe7/7yTk8ZqQH++ZIxlH9cyx91XLiIxEk0e+DnA9cDM82sPHK7HLgfuNjMNgMXR5aTxtXnDGFcQTY//fMmGpp1jRQR6XnRHIWy0jlnzrmznXMlkduLzrlq59ws51xR5D6pBoT9PuNfrhjHrtojPPrG9niXIyJJSGdidsH5o/sza2w+v3x5C9WHG+NdjogkGQV4F911+Tjqm4M89NLmeJciIklGAd5Fo/Mz+fLkYfz+
7Z1s2Xco3uWISBJRgHeD2y4qIiPFz7+/uCnepYhIElGAd4PczDS+M3M0L2/ax5/f02GFItIzFODd5KYLRjBhcDb/smS9ftAUkR6hAO8mKX4fP5tfzMGGZu5ZuiHe5YhIElCAd6OxA7O57aIz+NO7lbzw7u54lyMiCU4B3s2+OW0kZw/pwz3Pb2C/hlJEJIYU4N0s4PfxX/OLOdzQwg+XrNfFrkQkZhTgMVA0IIvvXXwGf16/Rxe7EpGYUYDHyDemjqB4aA73PL+ePQca4l2OiCQgBXiMHB1KaWoJ8e1Fa2hs0RULRaR7KcBjaHR+Jv81v5h3dtZyz5INGg8XkW6lAI+xy84q4DszRrO47GP+d9XOeJcjIglEAd4DvnfxGcwYk8e9Szfw9kdJddl0EYkhBXgP8PuMh66byNB+GdyyaA2VB47EuyQRSQAK8B7Sp1cK//O1c2loDvGt/7dG07CJSJcpwHvQ6PwsHvhiMesqDvDdJ96hORiKd0ki4mEK8B52yZkDuW/umSx7fy/fW1xOMKQjU0Skc9oNcDN7xMz2mdn6Vm39zGyZmW2O3PeNbZmJ5WufK+T/XD6WF96t5I5n3iWkEBeRTohmD/xRYPYJbXcCy51zRcDyyLJ0wMJpo/jeRWfw9JoK7lmqa6aISMcF2uvgnFthZoUnNM8FLow8fgx4BbijG+tKCv80azRHmoP8+tWtpAf83H3FOMws3mWJiEe0G+CnMMA5VwngnKs0s/xTdTSzhcBCgGHDhnVydYnJzLhj9hgamoP8ZuVHOODuy8fh8ynERaR9nQ3wqDnnHgYeBigtLdU4wQnMjH/9h/EA/HblR+yuPcKD15aQnuKPc2Uicrrr7FEoe82sACByv6/7Sko+R0P8h1eO5y8b9vCl/3lL82qKSLs6G+BLgQWRxwuA57unnORlZtx8wQh+9ZVzeH/3Qa76v2+wrepwvMsSkdNYNIcRPgG8CYwxswozuxm4H7jYzDYDF0eWpRvMnlDAkwvPo66xhS/86g3e2lYd75JE5DRlPXn4WmlpqSsrK+ux9XnZzup6bnj0bbbvr+OfZhXxnRmjCfh13pVIMjKzNc650hPblQinqWG5GSz9zgXMKxnMQy9t5kv/8xa7anURLBH5lAL8NJaZFuCBa0t48Npi3t99kMseWsGf39McmyISpgD3gKsmDuHFW6cyon9vvr1oLT94eh2f1DXFuywRiTMFuEcMz+3NH771eb594SieWbuLmf/1CotX79R1VESSmALcQ1IDPu6YPZY//dMFFOVncccz73H1r99g/a4D8S5NROJAAe5BYwdms/ib5/HAF4v5uKaeOb9YyQ+XrGffwYZ4lyYiPUiHEXrcgSPNPPC3D/jfVTsJ+Iyvnjecb00fRV5WWrxLE5FucqrDCBXgCWJHdR0/f3kLz66tIDXg4/rzhvPN6aPon6kgF/E6BXiS+Gh/HT9fvpkl5btIDfi4auIQvva54YwryI53aSLSSQrwJLO16jAPv7qNJeW7aGwJMXlEPxZ8rpBLzhxAis7oFPEUBXiSqq1v4g9lFTz+1nY+rjnCgOw05k0czNziwYwryNIEEiIeoABPcsGQ49UP97HorZ28+mEVLSFHUX4mc0sGMad4MMNyM+JdooicggJcjqmpa+LF9ypZWr6bt7fXADC+IJuZY/OZMTafkqE5+DUrkMhpQwEubdpVe4QX1u1m+aZ9rNnxCcGQo1/vVKafkce0M/ozeUQug3N6xbtMkaSmAJd2Hahv5tXNVfx90z5e+WAfn9Q3AzCkby8mj+jHeSNyObewLyNye2veTpEepACXDgmGHJv2HOTtj2pYta2Gt7fXUBO5gFZWeoCzBvfhrCF9OHtwDmcN7sOQvr0U6iIxogCXLnHOsXnfYcp31rKuopb3dh1gY+VBmoPhz09Gqp+i/EyKBmQxZkAWRQMyGdG/N4NzemkiCpEuUoBLt2tsCfLhnsOs332AD/ceitwOU3Xo0wmZAz5jaL8M
CnMzGJ7bmyF9ezE4pxeDcnoxuG8vcnun6lBGkXacKsAD8ShGEkNawM9ZQ8JDKa3V1DWxee8hdlTXs726ju3VdXy0v55VH9VQ3xQ84T18DMhOJz8rjQHZ6eRlpZGfnUb/zDT6Z6bSr3caub1T6dc7lYxUv8JepBUFuHS7fr1TmTIylykjc49rd85x4Egzu2qPsOuTI+yuPcLuAw3sPRi+bdxzkFc/bORwY0ub75sa8JHTK4WcjBRyeqXSJyOF7PQUstIDZKcHyEpPITM9QFZ6gN6pATJS/fROC4RvqX7SU/1kpPg1pCMJo0sBbmazgf8G/MBvnHOanV5OyczIyUglJyOVMwf1OWW/usYWqg83UV3XSE1dE9V1TdTUNfFJfRMH6puprW+m9kgTH9fUc/BIM4caWzjc2EK0o4EpfiM9xU+vFD/pKX7SU3zh+4CftBQfaQEfqQEfqf7wfVrAT2rAR4rfR6rfSPH7SAn4CPjCjwN+I8UXvg/4w+1+n5HiN/w+H34LL7e+BXyGL9LuM/D57Fg/X6TNb4a16mMY5gOfRV5jhkXaj1vWXylJo9MBbmZ+4JfAxUAFsNrMljrn3u+u4iQ5Hd1r7sjZoaGQo64pHOSHGlqoa2yhvilIXWNLpD1IY3OQ+qYgR5qDHGkK3xpagjQ2h2hoCdLQHORQQwvVLSGagiGaWkI0tgRpagnRHHTH2rzAjgZ65LG1fszR4P+0ncgyR9taPR95Gjva8dj7fLqu8DN23PKnvU/9pXLstce95uT3OfH9Tn6fk59ps2+U320d+QqM9gvzkQWTuv2M567sgU8GtjjntgGY2ZPAXEABLj3O5zOy0lPISk+h4NQ7913mnCMYcscCvSUYoiXkaA6GaAk6WkLh5ZZguF9LKHx/7OYcwVC4b8g5giEIuaOPwzfnIBhpC4UcIXe0T3j9R593kfYT21y40HB/jvYLP+bYa8L9XKs+rbfx6HPw6fPu2PPhVlr3adUXTmzjpLbj+h7X9mkNJ/23P+W/SRttbfaL7k+0Dh3W0YHOqYHuH7rrSoAPBj5utVwBTDmxk5ktBBYCDBs2rAurE4k/M4sMlUAv/PEuR5JcV74S2vq74aTvI+fcw865UudcaV5eXhdWJyIirXUlwHlc7UkAAARRSURBVCuAoa2WhwC7u1aOiIhEqysBvhooMrMRZpYKXAcs7Z6yRESkPZ0eA3fOtZjZd4C/Ej6M8BHn3IZuq0xERD5Tl44Dd869CLzYTbWIiEgH6JQ0ERGPUoCLiHiUAlxExKN69HKyZlYF7Ojky/sD+7uxnHhLpO1JpG2BxNqeRNoWSN7tGe6cO+lEmh4N8K4ws7K2rofrVYm0PYm0LZBY25NI2wLanhNpCEVExKMU4CIiHuWlAH843gV0s0TankTaFkis7UmkbQFtz3E8MwYuIiLH89IeuIiItKIAFxHxKE8EuJnNNrMPzGyLmd0Z73o6wsweMbN9Zra+VVs/M1tmZpsj933jWWNHmNlQM/u7mW00sw1mdmuk3XPbZGbpZva2ma2LbMu9kfYRZrYqsi2LI1fb9Awz85vZO2b2QmTZk9tjZtvN7D0zKzezskib5z5nR5lZjpk9bWabIv//fK6r23PaB3iruTcvA8YDXzKz8fGtqkMeBWaf0HYnsNw5VwQsjyx7RQvwz865ccB5wD9G/j28uE2NwEznXDFQAsw2s/OAnwIPRrblE+DmONbYGbcCG1ste3l7ZjjnSlodK+3Fz9lR/w38xTk3Figm/G/Ute1xx+bUOz1vwOeAv7Zavgu4K951dXAbCoH1rZY/AAoijwuAD+JdYxe27XnCE1t7epuADGAt4WkB9wOBSPtxn7/T/UZ4YpXlwEzgBcIzZ3lye4DtQP8T2jz5OQOygY+IHDjSXdtz2u+B0/bcm4PjVEt3GeCcqwSI3OfHuZ5OMbNCYCKwCo9uU2S4oRzYBywDtgK1zrmWSBevfd4eAn4AhCLLuXh3exzwNzNbE5lbFzz6
OQNGAlXA7yLDW78xs950cXu8EOBRzb0pPcvMMoFngNuccwfjXU9nOeeCzrkSwnuuk4FxbXXr2ao6x8yuBPY559a0bm6jqye2BzjfOXcO4eHTfzSzafEuqAsCwDnAr5xzE4E6umH4xwsBnohzb+41swKAyP2+ONfTIWaWQji8Fznnno00e3qbnHO1wCuEx/VzzOzoZCde+rydD8wxs+3Ak4SHUR7Co9vjnNsdud8HPEf4C9arn7MKoMI5tyqy/DThQO/S9nghwBNx7s2lwILI4wWEx5E9wcwM+C2w0Tn3QKunPLdNZpZnZjmRx72Aiwj/sPR34JpIN09sC4Bz7i7n3BDnXCHh/09eds59BQ9uj5n1NrOso4+BS4D1ePBzBuCc2wN8bGZjIk2zgPfp6vbEe3A/yh8ALgc+JDw+eXe86+lg7U8AlUAz4W/hmwmPSy4HNkfu+8W7zg5szwWE/wR/FyiP3C734jYBZwPvRLZlPXBPpH0k8DawBfgDkBbvWjuxbRcCL3h1eyI1r4vcNhz9/96Ln7NW21QClEU+b0uAvl3dHp1KLyLiUV4YQhERkTYowEVEPEoBLiLiUQpwERGPUoCLiHiUAlxExKMU4CIiHvX/Aa4PW4PTBecpAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Step 5: show the results. Plot the original (x, y) data together with the output fitted by the network\n",
    "predicted = model(torch.from_numpy(x_train)).detach().numpy()  # model output for the training inputs\n",
    "\n",
    "plt.plot(x_train, y_train, 'ro', label='Original data')  # raw data points\n",
    "plt.plot(x_train, predicted, label='Fitted line')  # line fitted by the model\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# Plot how the loss changes over the training iterations\n",
    "plt.plot(loss_history, label='loss for every epoch')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": true
   },
   "source": [
    "## 常见的指标"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 分类指标"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-class confusion matrix: rows index the true class, columns index the predicted class.\n",
    "# `output` (model scores) and `target` (labels) must already exist in the notebook namespace.\n",
    "predict = output.argmax(dim = 1)\n",
    "confusion_matrix =torch.zeros(2,2)\n",
    "for p,t in zip(predict.view(-1), target.view(-1)):\n",
    "    confusion_matrix[t.long(), p.long()] += 1\n",
    "# NOTE: a_* / b_* refer to class 0 / class 1; *_p is read as precision and *_r as recall.\n",
    "# Precision = TP / column sum (everything predicted as the class).\n",
    "a_p = (confusion_matrix.diag() / confusion_matrix.sum(0))[0]\n",
    "b_p = (confusion_matrix.diag() / confusion_matrix.sum(0))[1]\n",
    "# Recall = TP / row sum (everything that truly belongs to the class).\n",
    "a_r = (confusion_matrix.diag() / confusion_matrix.sum(1))[0]\n",
    "b_r = (confusion_matrix.diag() / confusion_matrix.sum(1))[1]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 语义分割MIOU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mIoU_metrics(pred, mask):\n",
    "    \"\"\"\n",
    "\n",
    "    :param pred: (N,C,H,W)\n",
    "    :param mask:  (N,C,H,W)\n",
    "    :return: IOU (C)\n",
    "    \"\"\"\n",
    "    N, C, H, W = mask.size()\n",
    "    eps=1e-6\n",
    "    input_flat = pred.permute(0, 2, 3, 1).reshape(N * H * W, -1)\n",
    "    target_flat = mask.permute(0, 2, 3, 1).reshape(N * H * W, -1)\n",
    "    intersection = input_flat * target_flat\n",
    "    iou = (intersection.sum(0) + eps) / (input_flat.sum(0) + target_flat.sum(0)-intersection.sum(0) + eps)\n",
    "    return iou\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 常用的优化器\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGD over all model parameters. `lr`, `momentum` and `weight_decay` are not set in this cell\n",
    "# and must already exist in the notebook namespace.\n",
    "optimizer = torch.optim.SGD(model.parameters(),\n",
    "                                 lr=lr,\n",
    "                                 momentum=momentum,\n",
    "                                 weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 常用的学习率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lr scheduling\n",
    "# NOTE(review): WarmupPolyLR is not defined in this cell; presumably a custom\n",
    "# warmup + polynomial-decay scheduler that has to be imported first -- confirm.\n",
    "# `epochs`, `iters_per_epoch` and the warmup_* values are likewise not set in this cell.\n",
    "lr_scheduler = WarmupPolyLR(optimizer,\n",
    "                                 max_iters=epochs * iters_per_epoch,\n",
    "                                 power=0.9,\n",
    "                                 warmup_factor=warmup_factor,\n",
    "                                 warmup_iters=warmup_iters,\n",
    "                                 warmup_method=warmup_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiply the learning rate by `gamma` once epoch 10 and again once epoch 20 is reached.\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [10,20], gamma = 0.1)\n",
    "# Advance the schedule; in a training loop call this once per epoch, after optimizer.step().\n",
    "scheduler.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": true
   },
   "source": [
    "## 训练过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(weight_dir, results_dir, n_epochs, optimizer, model, loss_fn, trian_dataloader, val_dataloader, model_name, save_best=True, resume = False):\n",
    "    \"\"\"Train a classifier, validate after every epoch and log the history to CSV.\n",
    "\n",
    "    Tensors and the model are placed on the notebook-level ``device``.\n",
    "\n",
    "    :param weight_dir: directory (relative to the working directory) of ``<model_name>.pt``\n",
    "    :param results_dir: directory (relative to the working directory) of ``<model_name>_history.csv``\n",
    "    :param n_epochs: number of epochs to run\n",
    "    :param optimizer: optimizer stepped once per training batch\n",
    "    :param model: network; predictions are taken as the argmax along dim 1 of its output\n",
    "    :param loss_fn: callable ``loss_fn(preds, label)`` returning a scalar loss tensor\n",
    "    :param trian_dataloader: iterable of ``(image, label)`` training batches\n",
    "        (the misspelled name is kept so keyword callers keep working)\n",
    "    :param val_dataloader: iterable of ``(image, label)`` validation batches\n",
    "    :param model_name: file stem shared by the checkpoint and the history file\n",
    "    :param save_best: if True, save the weights whenever the mean validation loss improves\n",
    "    :param resume: if True, reload the checkpoint and the previous history before training\n",
    "    :return: None; the checkpoint and the history CSV are written as side effects\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "    from collections import deque\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "\n",
    "    weight_path = \"./%s/%s.pt\" % (weight_dir, model_name)\n",
    "    history_path = \"./%s/%s_history.csv\" % (results_dir, model_name)\n",
    "    history_columns = ['lr', 'train_loss', 'val_loss', 'train_accuracy', 'val_accuracy']\n",
    "\n",
    "    min_loss = 1e10\n",
    "    train_history = []\n",
    "    if resume:\n",
    "        # map_location lets a checkpoint saved on another device load on the current one.\n",
    "        model.load_state_dict(torch.load(weight_path, map_location=device))\n",
    "        pf = pd.read_csv(history_path)\n",
    "        train_history = pf.values.tolist()\n",
    "        # Carry over the best logged validation loss so that a resumed run does not\n",
    "        # overwrite the checkpoint with a worse model on its first epoch.\n",
    "        if 'val_loss' in pf.columns:\n",
    "            best_logged = pf['val_loss'].min()\n",
    "            if np.isfinite(best_logged):\n",
    "                min_loss = float(best_logged)\n",
    "\n",
    "    # Move the model once up front instead of on every batch.\n",
    "    model = model.to(device)\n",
    "\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "\n",
    "        model.train()\n",
    "        train_loss = deque(maxlen=20)  # only the last 20 batch losses are averaged\n",
    "        train_total = 0\n",
    "        train_correct = 0\n",
    "        train_bar = tqdm(trian_dataloader)\n",
    "        for image, label in train_bar:\n",
    "            image = image.to(device)\n",
    "            label = label.to(device)\n",
    "\n",
    "            train_preds = model(image)\n",
    "            loss_train = loss_fn(train_preds, label)\n",
    "            train_loss.append(loss_train.item())\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss_train.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Running training accuracy, measured with the model in train mode.\n",
    "            _, predicted = torch.max(train_preds.data, 1)\n",
    "            train_total += label.size(0)\n",
    "            train_correct += (predicted == label).sum().item()\n",
    "            train_bar.set_description(\"Lr:{:.6f},Loss:{:.4f}\".format(optimizer.param_groups[0]['lr'], np.mean(train_loss)))\n",
    "\n",
    "        torch.cuda.empty_cache()\n",
    "        model.eval()\n",
    "        val_loss = []\n",
    "        total = 0\n",
    "        correct = 0\n",
    "        with torch.no_grad():\n",
    "            for image, label in tqdm(val_dataloader):\n",
    "                image = image.to(device)\n",
    "                label = label.to(device)\n",
    "\n",
    "                val_preds = model(image)\n",
    "                loss_val = loss_fn(val_preds, label)\n",
    "                val_loss.append(loss_val.item())\n",
    "                _, predicted = torch.max(val_preds.data, 1)\n",
    "                total += label.size(0)\n",
    "                correct += (predicted == label).sum().item()\n",
    "\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        mean_train_loss = np.mean(train_loss)\n",
    "        mean_val_loss = np.mean(val_loss)\n",
    "        # Accuracies are percentages; nan when the corresponding loader yielded no samples.\n",
    "        train_accuracy = 100 * train_correct / train_total if train_total else float('nan')\n",
    "        val_accuracy = 100 * correct / total if total else float('nan')\n",
    "\n",
    "        if save_best and mean_val_loss < min_loss:\n",
    "            min_loss = mean_val_loss\n",
    "            torch.save(model.state_dict(), weight_path)\n",
    "\n",
    "        print(\"\\nepoch:{:d}/{:d}, Lr:{:.6f},Loss:{:.4f},val_loss:{:.4f},val_accuracy;{:.4f}\".\n",
    "              format(epoch, n_epochs, current_lr, mean_train_loss, mean_val_loss, val_accuracy))\n",
    "\n",
    "        # Rewrite the full history every epoch so an interrupted run keeps its log.\n",
    "        train_history.append([current_lr, mean_train_loss, mean_val_loss, train_accuracy, val_accuracy])\n",
    "        x = pd.DataFrame(train_history)\n",
    "        x.columns = history_columns\n",
    "        x.to_csv(history_path, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存训练记录到csv文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "# Placeholder history of 1000 rows x 6 metrics; np.random.random takes the shape as one tuple.\n",
    "train_history = np.random.random((1000, 6))\n",
    "x = pd.DataFrame(train_history)\n",
    "x.columns= ['Loss', 'val_loss', 'val_pixAcc', 'train_pixAcc', 'val_mIoU', 'train_mIoU']\n",
    "# `model_name` must already exist in the notebook namespace.\n",
    "x.to_csv(\"./checkpoint/%s.csv\"%model_name, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Show GPU utilization\n",
    "# Use the %pip magic (not !pip) so the package is installed into the environment\n",
    "# of the kernel that is running this notebook.\n",
    "%pip install gputil\n",
    "from GPUtil import showUtilization as gpu_usage\n",
    "gpu_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型验证过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_evaluate(model, val_dataloader, loss_fn):\n",
    "    \"\"\"Run ``model`` over a labelled loader and collect outputs, losses and accuracy.\n",
    "\n",
    "    The model is switched to eval mode for the duration of the call and its previous\n",
    "    train/eval mode is restored before returning.\n",
    "\n",
    "    Args:\n",
    "        model: module to evaluate; it is moved to the global ``device``.\n",
    "        val_dataloader: iterable yielding ``(image, label)`` batches.\n",
    "        loss_fn: called as ``loss_fn(preds, label)``; ``.item()`` is taken of its result.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(val_preds, val_loss, accuracy)`` where ``val_preds`` is a list with one\n",
    "        CPU tensor of model outputs per batch, ``val_loss`` is a list with one float per\n",
    "        batch, and ``accuracy`` is the fraction of samples whose arg-max over dim 1 equals\n",
    "        the label (0.0 if the loader yields no samples).\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    val_preds = []\n",
    "    val_loss = []\n",
    "    total = 0\n",
    "    correct = 0\n",
    "\n",
    "    # Moving the model is loop-invariant, so do it once instead of once per batch.\n",
    "    model = model.to(device)\n",
    "    # Dropout / BatchNorm must run in inference mode for a meaningful evaluation.\n",
    "    # Remember the current mode so a surrounding training loop is not disturbed.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # Wrap the loader itself (not enumerate(...)) so tqdm can read its length.\n",
    "            for image, label in tqdm(val_dataloader):\n",
    "                image = image.to(device)\n",
    "                label = label.to(device)\n",
    "\n",
    "                preds = model(image)\n",
    "                loss_val = loss_fn(preds, label)\n",
    "\n",
    "                val_preds.append(preds.cpu())\n",
    "                val_loss.append(loss_val.item())\n",
    "\n",
    "                # Predicted class = index of the largest score along dim 1.\n",
    "                _, predicted = torch.max(preds, 1)\n",
    "                total += label.size(0)\n",
    "                correct += (predicted == label).sum().item()\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    # Guard against an empty loader instead of raising ZeroDivisionError.\n",
    "    accuracy = correct / total if total else 0.0\n",
    "\n",
    "    return val_preds, val_loss, accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_predict(model, test_dataloader):\n",
    "    \"\"\"Run ``model`` over an unlabelled loader and collect its outputs.\n",
    "\n",
    "    The model is switched to eval mode for the duration of the call and its previous\n",
    "    train/eval mode is restored before returning.\n",
    "\n",
    "    Args:\n",
    "        model: module to run; it is moved to the global ``device``.\n",
    "        test_dataloader: iterable yielding image batches (no labels).\n",
    "\n",
    "    Returns:\n",
    "        A list with one CPU tensor of model outputs per batch.\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    val_preds = []\n",
    "    # Moving the model is loop-invariant, so do it once instead of once per batch.\n",
    "    model = model.to(device)\n",
    "    # Dropout / BatchNorm must run in inference mode when predicting.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # Wrap the loader itself (not enumerate(...)) so tqdm can read its length.\n",
    "            for image in tqdm(test_dataloader):\n",
    "                image = image.to(device)\n",
    "                preds = model(image)\n",
    "                val_preds.append(preds.cpu())\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    return val_preds\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc-hr-collapsed": true
   },
   "source": [
    "# 高级技巧"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GEM\n",
    "generalized mean pooling，出自Fine-tuning CNN Image Retrieval with No Human Annotation，提出的是一种可学习的pooling layer，可提高检索性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class GeneralizedMeanPooling(nn.Module):\n",
    "    \"\"\"Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.\n",
    "    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`\n",
    "        - At p = infinity, one gets Max Pooling\n",
    "        - At p = 1, one gets Average Pooling\n",
    "    The output is of size H x W, for any input size.\n",
    "    The number of output features is equal to the number of input planes.\n",
    "    Args:\n",
    "        output_size: the target output size of the image of the form H x W.\n",
    "                     Can be a tuple (H, W) or a single H for a square image H x H\n",
    "                     H and W can be either a ``int``, or ``None`` which means the size will\n",
    "                     be the same as that of the input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, norm, output_size=1, eps=1e-6):\n",
    "        super(GeneralizedMeanPooling, self).__init__()\n",
    "        assert norm > 0\n",
    "        self.p = float(norm)\n",
    "        self.output_size = output_size\n",
    "        self.eps = eps\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.clamp(min=self.eps).pow(self.p)\n",
    "        return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return self.__class__.__name__ + '(' \\\n",
    "            + str(self.p) + ', ' \\\n",
    "            + 'output_size=' + str(self.output_size) + ')'\n",
    "\n",
    "\n",
    "class GeneralizedMeanPoolingP(GeneralizedMeanPooling):\n",
    "    \"\"\"GeM pooling whose exponent p is a trainable parameter initialised to ``norm``.\"\"\"\n",
    "\n",
    "    def __init__(self, norm=3, output_size=1, eps=1e-6):\n",
    "        super().__init__(norm, output_size, eps)\n",
    "        # Replace the plain float set by the parent with a one-element Parameter\n",
    "        # so the exponent is registered on the module and can receive gradients.\n",
    "        initial_p = torch.ones(1) * norm\n",
    "        self.p = Parameter(initial_p)\n",
    "\n",
    "# Smoke test: pool a random (2, 7, 7) tensor with fixed exponent p = 2.\n",
    "if __name__ == '__main__':\n",
    "    inp = torch.randn(2, 7, 7)\n",
    "    pool = GeneralizedMeanPooling(norm=2)\n",
    "    out = pool(inp)  # expected shape: (2, 1, 1)\n",
    "    print(out.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [BNNeck](https://zhuanlan.zhihu.com/p/61831669)\n",
    "\n",
    "出自罗浩博士的Bag of Tricks and A Strong Baseline for Deep Person Re-identification\n",
    "其实就是在feature层和fc layer之间增加一层Batch Normalization layer，然后在retrieval的时候，\n",
    "使用BN后的feature再做一个l2 norm，也就是retrieval with Cosine distance。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class res50(torch.nn.Module):\n",
    "    \"\"\"ResNet-50 feature extractor with a BNNeck and a bias-free linear classifier.\n",
    "\n",
    "    In training mode ``forward`` returns the classifier logits computed from the\n",
    "    batch-normalised feature; in eval mode it returns the L2-normalised\n",
    "    batch-normalised feature instead.\n",
    "\n",
    "    Args:\n",
    "        num_classes: output dimension of the classifier.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        pretrained = resnet50(pretrained=True)\n",
    "        # NOTE(review): resnet.maxpool is not part of this stem, while the res18 model\n",
    "        # in this notebook includes it - confirm the omission is intentional.\n",
    "        stages = [\n",
    "            pretrained.conv1,\n",
    "            pretrained.bn1,\n",
    "            pretrained.relu,\n",
    "            pretrained.layer1,\n",
    "            pretrained.layer2,\n",
    "            pretrained.layer3,\n",
    "            pretrained.layer4,\n",
    "        ]\n",
    "        self.backbone = nn.Sequential(*stages)\n",
    "        self.pool = nn.AdaptiveMaxPool2d(1)\n",
    "        self.bnneck = nn.BatchNorm1d(2048)\n",
    "        # Exclude the BN shift (bias) from gradient updates.\n",
    "        self.bnneck.bias.requires_grad_(False)\n",
    "        self.classifier = nn.Linear(2048, num_classes, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        pooled = self.pool(self.backbone(x))\n",
    "        # Flatten everything after the batch dimension before the 1-D batch norm.\n",
    "        feat = self.bnneck(pooled.view(pooled.shape[0], -1))\n",
    "        if self.training:\n",
    "            return self.classifier(feat)\n",
    "        # Eval mode: return the feature scaled to unit L2 norm along dim 1.\n",
    "        return nn.functional.normalize(feat, dim=1, p=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## dual pooling\n",
    "\n",
    "这种是在模型层进行改造的一种小trick了，常见的做法：global max/average pooling + fc layer，这里是concat(global max-pooling, global average pooling) + fc layer，其实就是为了丰富特征层，max pooling更加关注重要的局部特征，而average pooling是更加关注全局的特征。不一定有效，我试过不少次，有效的次数比较少，但不少人喜欢这样用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class res18(nn.Module):\n",
    "    \"\"\"ResNet-18 classifier using dual (global average + global max) pooling.\n",
    "\n",
    "    The two pooled feature maps are concatenated along the channel dimension,\n",
    "    reduced back to 512 channels by a 1x1 convolution, flattened and passed\n",
    "    through dropout and a linear layer.\n",
    "\n",
    "    Args:\n",
    "        num_classes: output dimension of the final linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(res18, self).__init__()\n",
    "        # The full ResNet is kept as an attribute, so all of its layers stay\n",
    "        # registered on this module in addition to the `feature` stack below.\n",
    "        self.base = resnet18(pretrained=True)\n",
    "        self.feature = nn.Sequential(\n",
    "            self.base.conv1,\n",
    "            self.base.bn1,\n",
    "            self.base.relu,\n",
    "            self.base.maxpool,\n",
    "            self.base.layer1,\n",
    "            self.base.layer2,\n",
    "            self.base.layer3,\n",
    "            self.base.layer4\n",
    "        )\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.max_pool = nn.AdaptiveMaxPool2d(1)\n",
    "        # 1x1 convolution mapping the concatenated avg + max channels (1024) to 512.\n",
    "        self.reduce_layer = nn.Conv2d(1024, 512, 1)\n",
    "        self.fc  = nn.Sequential(\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(512, num_classes)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        bs = x.shape[0]\n",
    "        x = self.feature(x)\n",
    "        # Pool once and keep the results 4-D: reduce_layer is a Conv2d, so the\n",
    "        # flattening happens only after it. (The previous version also computed a\n",
    "        # flattened pair first and immediately overwrote it.)\n",
    "        x1 = self.avg_pool(x)\n",
    "        x2 = self.max_pool(x)\n",
    "        x = torch.cat([x1, x2], dim=1)\n",
    "        x = self.reduce_layer(x).view(bs, -1)\n",
    "        logits = self.fc(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## unet的改进\n",
    "\n",
    "很多top solution都是修改Unet的Decoder，最常见的有如下几种。\n",
    "* [oc block in decoder](https://github.com/tugstugi/pytorch-saltnet)\n",
    "\n",
    "* scse block\n",
    "\n",
    "* Hypercolumn block\n",
    "\n",
    "* CBAM（Convolutional Block Attention Module，bestfitting比较喜欢用）\n",
    "\n",
    "* BAM（Bottleneck attention module）\n",
    "\n",
    "这些注意力block一般是放在decoder不同stage出来的feature后面，因为注意力机制往往都是来优化feature的。\n",
    "\n",
    "* dual head(multi task learning)，也就是构造一个end2end带有分割与分类的模型。\n",
    "\n",
    "同时，多任务学习往往会降低模型过拟合的程度，并可以提升模型的性能。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  },
  "toc-autonumbering": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
